From 0e8516a12b51d22151a44fc2a139e16067def2db Mon Sep 17 00:00:00 2001 From: Rohan Agarwal Date: Tue, 17 Dec 2024 14:01:22 -0800 Subject: [PATCH 1/5] Add prompt caching --- requirements-constraints.txt | 2 +- requirements.txt | 6 ++---- src/seer/automation/agent/client.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/requirements-constraints.txt b/requirements-constraints.txt index 4fff18586..3db3375b5 100644 --- a/requirements-constraints.txt +++ b/requirements-constraints.txt @@ -99,7 +99,7 @@ chromadb==0.4.14 google-cloud-storage==2.* google-cloud-aiplatform==1.* google-cloud-secret-manager==2.* -anthropic[vertex]==0.34.2 +anthropic[vertex]==0.41.0 langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e watchdog stumpy==1.13.0 diff --git a/requirements.txt b/requirements.txt index 44a6aaaf0..8b0248eee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ amqp==5.3.1 # via kombu annotated-types==0.7.0 # via pydantic -anthropic==0.34.2 +anthropic==0.41.0 # via -r requirements-constraints.txt anyio==4.7.0 # via @@ -268,7 +268,7 @@ httpcore==1.0.7 # via httpx httptools==0.6.4 # via uvicorn -httpx==0.27.2 +httpx==0.28.1 # via # -r requirements-constraints.txt # anthropic @@ -700,7 +700,6 @@ sniffio==1.3.1 # via # anthropic # anyio - # httpx # openai sqlalchemy==2.0.25 # via @@ -730,7 +729,6 @@ threadpoolctl==3.2.0 # scikit-learn tokenizers==0.15.2 # via - # anthropic # chromadb # transformers torch==2.2.0 diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index 3f4ae9d55..3e26d65f9 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -7,6 +7,7 @@ import anthropic from anthropic import NOT_GIVEN from anthropic.types import ( + CacheControlEphemeralParam, MessageParam, TextBlockParam, ToolParam, @@ -542,10 +543,22 @@ def _prep_message_and_tools( message_dicts = [cls.to_message_param(message) 
for message in messages] if messages else [] if prompt: message_dicts.append(cls.to_message_param(Message(role="user", content=prompt))) + # Set caching breakpoints for the last message and the 3rd last message + if len(message_dicts) > 0 and message_dicts[-1]["content"]: + message_dicts[-1]["content"][0]["cache_control"] = CacheControlEphemeralParam( + type="ephemeral" + ) + if len(message_dicts) >= 4 and message_dicts[-4]["content"]: + message_dicts[-4]["content"][0]["cache_control"] = CacheControlEphemeralParam( + type="ephemeral" + ) tool_dicts = ( [cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None ) + # set caching breakpoint at end of tools + if tool_dicts: + tool_dicts[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") return message_dicts, tool_dicts, system_prompt From 64c93acfeb2f4b7b70de9c671866f4ca1a85c981 Mon Sep 17 00:00:00 2001 From: Rohan Agarwal Date: Wed, 18 Dec 2024 14:36:15 -0800 Subject: [PATCH 2/5] change pattern --- src/seer/automation/agent/client.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index 3e26d65f9..3aeb9c9f0 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -416,7 +416,7 @@ def generate_text( max_tokens: int | None = None, timeout: float | None = None, ): - message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools( + message_dicts, tool_dicts, system_message = self._prep_message_and_tools( messages=messages, prompt=prompt, system_prompt=system_prompt, @@ -426,7 +426,7 @@ def generate_text( anthropic_client = self.get_client() completion = anthropic_client.messages.create( - system=system_prompt or NOT_GIVEN, + system=system_message or NOT_GIVEN, model=self.model_name, tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN, messages=cast(Iterable[MessageParam], message_dicts), @@ 
-539,7 +539,7 @@ def _prep_message_and_tools( prompt: str | None = None, system_prompt: str | None = None, tools: list[FunctionTool] | None = None, - ) -> tuple[list[MessageParam], list[ToolParam] | None, str | None]: + ) -> tuple[list[MessageParam], list[ToolParam] | None, list[TextBlockParam] | None]: message_dicts = [cls.to_message_param(message) for message in messages] if messages else [] if prompt: message_dicts.append(cls.to_message_param(Message(role="user", content=prompt))) @@ -548,19 +548,30 @@ def _prep_message_and_tools( message_dicts[-1]["content"][0]["cache_control"] = CacheControlEphemeralParam( type="ephemeral" ) - if len(message_dicts) >= 4 and message_dicts[-4]["content"]: - message_dicts[-4]["content"][0]["cache_control"] = CacheControlEphemeralParam( + if len(message_dicts) >= 3 and message_dicts[-3]["content"]: + message_dicts[-3]["content"][0]["cache_control"] = CacheControlEphemeralParam( type="ephemeral" ) tool_dicts = ( [cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None ) - # set caching breakpoint at end of tools if tool_dicts: tool_dicts[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") - return message_dicts, tool_dicts, system_prompt + system_message = ( + [ + TextBlockParam( + type="text", + text=system_prompt, + cache_control=CacheControlEphemeralParam(type="ephemeral"), + ) + ] + if system_prompt + else [] + ) + + return message_dicts, tool_dicts, system_message @observe(as_type="generation", name="Anthropic Stream") def generate_text_stream( @@ -574,7 +585,7 @@ def generate_text_stream( max_tokens: int | None = None, timeout: float | None = None, ) -> Iterator[str | ToolCall | Usage]: - message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools( + message_dicts, tool_dicts, system_message = self._prep_message_and_tools( messages=messages, prompt=prompt, system_prompt=system_prompt, @@ -584,7 +595,7 @@ def generate_text_stream( anthropic_client = self.get_client() stream = 
anthropic_client.messages.create( - system=system_prompt or NOT_GIVEN, + system=system_message or NOT_GIVEN, model=self.model_name, tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN, messages=cast(Iterable[MessageParam], message_dicts), From 0dafb39f99f38f5527e833b1f3f4711cb4f0e22d Mon Sep 17 00:00:00 2001 From: Rohan Agarwal Date: Wed, 18 Dec 2024 16:40:13 -0800 Subject: [PATCH 3/5] Update --- requirements-constraints.txt | 2 +- requirements.txt | 15 ++++++++------- tests/automation/agent/test_client.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/requirements-constraints.txt b/requirements-constraints.txt index 3db3375b5..798949641 100644 --- a/requirements-constraints.txt +++ b/requirements-constraints.txt @@ -99,7 +99,7 @@ chromadb==0.4.14 google-cloud-storage==2.* google-cloud-aiplatform==1.* google-cloud-secret-manager==2.* -anthropic[vertex]==0.41.0 +anthropic[vertex]==0.42.0 langfuse @ git+https://github.com/jennmueng/langfuse-python.git@9d9350de1e4e84fa548fe84f82c1b826be17956e watchdog stumpy==1.13.0 diff --git a/requirements.txt b/requirements.txt index 8b0248eee..73332bedb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ # aiohappyeyeballs==2.4.4 # via aiohttp -aiohttp==3.11.10 +aiohttp==3.11.11 # via # -r requirements-constraints.txt # datasets @@ -21,7 +21,7 @@ amqp==5.3.1 # via kombu annotated-types==0.7.0 # via pydantic -anthropic==0.41.0 +anthropic==0.42.0 # via -r requirements-constraints.txt anyio==4.7.0 # via @@ -200,7 +200,7 @@ google-auth==2.37.0 # google-cloud-secret-manager # google-cloud-storage # google-genai -google-cloud-aiplatform==1.74.0 +google-cloud-aiplatform==1.75.0 # via -r requirements-constraints.txt google-cloud-bigquery==3.27.0 # via google-cloud-aiplatform @@ -220,7 +220,7 @@ google-crc32c==1.6.0 # via # google-cloud-storage # google-resumable-media -google-genai==0.2.2 +google-genai==0.3.0 # via -r requirements-constraints.txt google-resumable-media==2.7.2 # 
via @@ -268,7 +268,7 @@ httpcore==1.0.7 # via httpx httptools==0.6.4 # via uvicorn -httpx==0.28.1 +httpx==0.27.2 # via # -r requirements-constraints.txt # anthropic @@ -426,7 +426,7 @@ onnx==1.16.0 # via -r requirements-constraints.txt onnxruntime==1.20.1 # via chromadb -openai==1.57.4 +openai==1.58.1 # via -r requirements-constraints.txt openapi-core==0.18.2 # via -r requirements-constraints.txt @@ -505,7 +505,7 @@ proto-plus==1.25.0 # google-cloud-aiplatform # google-cloud-resource-manager # google-cloud-secret-manager -protobuf==5.29.1 +protobuf==5.29.2 # via # -r requirements-constraints.txt # google-api-core @@ -700,6 +700,7 @@ sniffio==1.3.1 # via # anthropic # anyio + # httpx # openai sqlalchemy==2.0.25 # via diff --git a/tests/automation/agent/test_client.py b/tests/automation/agent/test_client.py index f4bf9ce2d..9bcb1639d 100644 --- a/tests/automation/agent/test_client.py +++ b/tests/automation/agent/test_client.py @@ -301,7 +301,7 @@ def test_anthropic_prep_message_and_tools(): assert "description" in tool_dicts[0] assert "input_schema" in tool_dicts[0] - assert returned_system_prompt == system_prompt + assert returned_system_prompt[0]["text"] == system_prompt @pytest.mark.vcr() From 57e26227b86bb21fab09d2b93c70a7f45c206bb7 Mon Sep 17 00:00:00 2001 From: Rohan Agarwal Date: Wed, 18 Dec 2024 16:51:49 -0800 Subject: [PATCH 4/5] Fix --- src/seer/automation/agent/client.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index 3aeb9c9f0..20b8d35b1 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -545,13 +545,9 @@ def _prep_message_and_tools( message_dicts.append(cls.to_message_param(Message(role="user", content=prompt))) # Set caching breakpoints for the last message and the 3rd last message if len(message_dicts) > 0 and message_dicts[-1]["content"]: - message_dicts[-1]["content"][0]["cache_control"] = 
CacheControlEphemeralParam( - type="ephemeral" - ) + message_dicts[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") if len(message_dicts) >= 3 and message_dicts[-3]["content"]: - message_dicts[-3]["content"][0]["cache_control"] = CacheControlEphemeralParam( - type="ephemeral" - ) + message_dicts[-3]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") tool_dicts = ( [cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None From 0a5f984f0e6b29314af18a8494bf8603f5272e09 Mon Sep 17 00:00:00 2001 From: Rohan Agarwal Date: Wed, 18 Dec 2024 16:54:08 -0800 Subject: [PATCH 5/5] Fix --- src/seer/automation/agent/client.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/seer/automation/agent/client.py b/src/seer/automation/agent/client.py index 20b8d35b1..f3e730988 100644 --- a/src/seer/automation/agent/client.py +++ b/src/seer/automation/agent/client.py @@ -544,16 +544,20 @@ def _prep_message_and_tools( if prompt: message_dicts.append(cls.to_message_param(Message(role="user", content=prompt))) # Set caching breakpoints for the last message and the 3rd last message - if len(message_dicts) > 0 and message_dicts[-1]["content"]: - message_dicts[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") - if len(message_dicts) >= 3 and message_dicts[-3]["content"]: - message_dicts[-3]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") + if len(message_dicts) > 0 and message_dicts[-1]["content"]: + message_dicts[-1]["content"][0]["cache_control"] = CacheControlEphemeralParam( + type="ephemeral" + ) + if len(message_dicts) >= 3 and message_dicts[-3]["content"]: + message_dicts[-3]["content"][0]["cache_control"] = CacheControlEphemeralParam( + type="ephemeral" + ) tool_dicts = ( [cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None ) if tool_dicts: - tool_dicts[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral") + tool_dicts[-1]["cache_control"] = 
CacheControlEphemeralParam(type="ephemeral") system_message = ( [