@@ -416,7 +416,7 @@ def generate_text(
416
416
max_tokens : int | None = None ,
417
417
timeout : float | None = None ,
418
418
):
419
- message_dicts , tool_dicts , system_prompt = self ._prep_message_and_tools (
419
+ message_dicts , tool_dicts , system_message = self ._prep_message_and_tools (
420
420
messages = messages ,
421
421
prompt = prompt ,
422
422
system_prompt = system_prompt ,
@@ -426,7 +426,7 @@ def generate_text(
426
426
anthropic_client = self .get_client ()
427
427
428
428
completion = anthropic_client .messages .create (
429
- system = system_prompt or NOT_GIVEN ,
429
+ system = system_message or NOT_GIVEN ,
430
430
model = self .model_name ,
431
431
tools = cast (Iterable [ToolParam ], tool_dicts ) if tool_dicts else NOT_GIVEN ,
432
432
messages = cast (Iterable [MessageParam ], message_dicts ),
@@ -539,7 +539,7 @@ def _prep_message_and_tools(
539
539
prompt : str | None = None ,
540
540
system_prompt : str | None = None ,
541
541
tools : list [FunctionTool ] | None = None ,
542
- ) -> tuple [list [MessageParam ], list [ToolParam ] | None , str | None ]:
542
+ ) -> tuple [list [MessageParam ], list [ToolParam ] | None , list [ TextBlockParam ] | None ]:
543
543
message_dicts = [cls .to_message_param (message ) for message in messages ] if messages else []
544
544
if prompt :
545
545
message_dicts .append (cls .to_message_param (Message (role = "user" , content = prompt )))
@@ -548,19 +548,30 @@ def _prep_message_and_tools(
548
548
message_dicts [- 1 ]["content" ][0 ]["cache_control" ] = CacheControlEphemeralParam (
549
549
type = "ephemeral"
550
550
)
551
- if len (message_dicts ) >= 4 and message_dicts [- 4 ]["content" ]:
552
- message_dicts [- 4 ]["content" ][0 ]["cache_control" ] = CacheControlEphemeralParam (
551
+ if len (message_dicts ) >= 3 and message_dicts [- 3 ]["content" ]:
552
+ message_dicts [- 3 ]["content" ][0 ]["cache_control" ] = CacheControlEphemeralParam (
553
553
type = "ephemeral"
554
554
)
555
555
556
556
tool_dicts = (
557
557
[cls .to_tool_dict (tool ) for tool in tools ] if tools and len (tools ) > 0 else None
558
558
)
559
- # set caching breakpoint at end of tools
560
559
if tool_dicts :
561
560
tool_dicts [- 1 ]["cache_control" ] = CacheControlEphemeralParam (type = "ephemeral" )
562
561
563
- return message_dicts , tool_dicts , system_prompt
562
+ system_message = (
563
+ [
564
+ TextBlockParam (
565
+ type = "text" ,
566
+ text = system_prompt ,
567
+ cache_control = CacheControlEphemeralParam (type = "ephemeral" ),
568
+ )
569
+ ]
570
+ if system_prompt
571
+ else []
572
+ )
573
+
574
+ return message_dicts , tool_dicts , system_message
564
575
565
576
@observe (as_type = "generation" , name = "Anthropic Stream" )
566
577
def generate_text_stream (
@@ -574,7 +585,7 @@ def generate_text_stream(
574
585
max_tokens : int | None = None ,
575
586
timeout : float | None = None ,
576
587
) -> Iterator [str | ToolCall | Usage ]:
577
- message_dicts , tool_dicts , system_prompt = self ._prep_message_and_tools (
588
+ message_dicts , tool_dicts , system_message = self ._prep_message_and_tools (
578
589
messages = messages ,
579
590
prompt = prompt ,
580
591
system_prompt = system_prompt ,
@@ -584,7 +595,7 @@ def generate_text_stream(
584
595
anthropic_client = self .get_client ()
585
596
586
597
stream = anthropic_client .messages .create (
587
- system = system_prompt or NOT_GIVEN ,
598
+ system = system_message or NOT_GIVEN ,
588
599
model = self .model_name ,
589
600
tools = cast (Iterable [ToolParam ], tool_dicts ) if tool_dicts else NOT_GIVEN ,
590
601
messages = cast (Iterable [MessageParam ], message_dicts ),
0 commit comments