Skip to content

Commit 64c93ac

Browse files
committed
change pattern
1 parent 0e8516a commit 64c93ac

File tree

1 file changed: +20 additions, −9 deletions

src/seer/automation/agent/client.py

+20-9
```diff
@@ -416,7 +416,7 @@ def generate_text(
         max_tokens: int | None = None,
         timeout: float | None = None,
     ):
-        message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
+        message_dicts, tool_dicts, system_message = self._prep_message_and_tools(
             messages=messages,
             prompt=prompt,
             system_prompt=system_prompt,
```
```diff
@@ -426,7 +426,7 @@ def generate_text(
         anthropic_client = self.get_client()

         completion = anthropic_client.messages.create(
-            system=system_prompt or NOT_GIVEN,
+            system=system_message or NOT_GIVEN,
             model=self.model_name,
             tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
             messages=cast(Iterable[MessageParam], message_dicts),
```
```diff
@@ -539,7 +539,7 @@ def _prep_message_and_tools(
         prompt: str | None = None,
         system_prompt: str | None = None,
         tools: list[FunctionTool] | None = None,
-    ) -> tuple[list[MessageParam], list[ToolParam] | None, str | None]:
+    ) -> tuple[list[MessageParam], list[ToolParam] | None, list[TextBlockParam] | None]:
         message_dicts = [cls.to_message_param(message) for message in messages] if messages else []
         if prompt:
             message_dicts.append(cls.to_message_param(Message(role="user", content=prompt)))
```
```diff
@@ -548,19 +548,30 @@ def _prep_message_and_tools(
             message_dicts[-1]["content"][0]["cache_control"] = CacheControlEphemeralParam(
                 type="ephemeral"
             )
-        if len(message_dicts) >= 4 and message_dicts[-4]["content"]:
-            message_dicts[-4]["content"][0]["cache_control"] = CacheControlEphemeralParam(
+        if len(message_dicts) >= 3 and message_dicts[-3]["content"]:
+            message_dicts[-3]["content"][0]["cache_control"] = CacheControlEphemeralParam(
                 type="ephemeral"
             )

         tool_dicts = (
             [cls.to_tool_dict(tool) for tool in tools] if tools and len(tools) > 0 else None
         )
-        # set caching breakpoint at end of tools
         if tool_dicts:
             tool_dicts[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral")

-        return message_dicts, tool_dicts, system_prompt
+        system_message = (
+            [
+                TextBlockParam(
+                    type="text",
+                    text=system_prompt,
+                    cache_control=CacheControlEphemeralParam(type="ephemeral"),
+                )
+            ]
+            if system_prompt
+            else []
+        )
+
+        return message_dicts, tool_dicts, system_message

     @observe(as_type="generation", name="Anthropic Stream")
     def generate_text_stream(
```
```diff
@@ -574,7 +585,7 @@ def generate_text_stream(
         max_tokens: int | None = None,
         timeout: float | None = None,
     ) -> Iterator[str | ToolCall | Usage]:
-        message_dicts, tool_dicts, system_prompt = self._prep_message_and_tools(
+        message_dicts, tool_dicts, system_message = self._prep_message_and_tools(
             messages=messages,
             prompt=prompt,
             system_prompt=system_prompt,
```
```diff
@@ -584,7 +595,7 @@ def generate_text_stream(
         anthropic_client = self.get_client()

         stream = anthropic_client.messages.create(
-            system=system_prompt or NOT_GIVEN,
+            system=system_message or NOT_GIVEN,
             model=self.model_name,
             tools=cast(Iterable[ToolParam], tool_dicts) if tool_dicts else NOT_GIVEN,
             messages=cast(Iterable[MessageParam], message_dicts),
```

Comments (0)