3 changes: 3 additions & 0 deletions .gitignore
@@ -149,7 +149,10 @@ wandb/

# asdf tool versions
.tool-versions

# ruff
/.ruff_cache/
.ruff_cache/

*.pkl
*.bin
125 changes: 54 additions & 71 deletions libs/aws/langchain_aws/chat_models/bedrock.py
@@ -98,16 +98,15 @@ def _convert_one_message_to_text_llama3(message: BaseMessage) -> str:
)
elif isinstance(message, HumanMessage):
message_text = (
f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>{message.content}<|eot_id|>"
)
elif isinstance(message, AIMessage):
message_text = (
f"<|start_header_id|>assistant"
f"<|end_header_id|>{message.content}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>{message.content}<|eot_id|>"
)
elif isinstance(message, SystemMessage):
message_text = (
f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
f"<|start_header_id|>system<|end_header_id|>{message.content}<|eot_id|>"
)
else:
raise ValueError(f"Got unknown type {message}")
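As a sanity check on the template above, a minimal sketch that reproduces the Llama 3 framing inline (the helper itself is private to langchain_aws; only langchain_core is assumed):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

def to_llama3_text(message) -> str:
    # Reproduces the role framing of _convert_one_message_to_text_llama3.
    roles = {HumanMessage: "user", AIMessage: "assistant", SystemMessage: "system"}
    role = roles[type(message)]
    return f"<|start_header_id|>{role}<|end_header_id|>{message.content}<|eot_id|>"

print(to_llama3_text(HumanMessage(content="Hello!")))
# -> <|start_header_id|>user<|end_header_id|>Hello!<|eot_id|>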
@@ -131,21 +130,15 @@ def _convert_one_message_to_text_llama4(message: BaseMessage) -> str:
f"<|header_start|>{message.role}<|header_end|>{message.content}<|eot|>"
)
elif isinstance(message, HumanMessage):
message_text = (
f"<|header_start|>user<|header_end|>{message.content}<|eot|>"
)
message_text = f"<|header_start|>user<|header_end|>{message.content}<|eot|>"
elif isinstance(message, AIMessage):
message_text = (
f"<|header_start|>assistant<|header_end|>{message.content}<|eot|>"
)
elif isinstance(message, SystemMessage):
message_text = (
f"<|header_start|>system<|header_end|>{message.content}<|eot|>"
)
message_text = f"<|header_start|>system<|header_end|>{message.content}<|eot|>"
elif isinstance(message, ToolMessage):
message_text = (
f"<|header_start|>ipython<|header_end|>{message.content}<|eom|>"
)
message_text = f"<|header_start|>ipython<|header_end|>{message.content}<|eom|>"
else:
raise ValueError(f"Got unknown type {message}")
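Two Llama 4 differences are worth calling out: headers switch to <|header_start|>/<|header_end|>, and tool results are rendered under the ipython role and closed with <|eom|> (end of message, turn continues) rather than <|eot|> (end of turn). A hedged sketch of just the tool branch:

from langchain_core.messages import ToolMessage

def tool_to_llama4_text(message: ToolMessage) -> str:
    # Tool output is attributed to the "ipython" role in Llama 4 prompts.
    return f"<|header_start|>ipython<|header_end|>{message.content}<|eom|>"

print(tool_to_llama4_text(ToolMessage(content='{"temp_c": 21}', tool_call_id="call_1")))
# -> <|header_start|>ipython<|header_end|>{"temp_c": 21}<|eom|>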

@@ -197,7 +190,7 @@ def convert_messages_to_prompt_anthropic(
"""
if messages is None:
return ""

messages = messages.copy() # don't mutate the original list
if len(messages) > 0 and not isinstance(messages[-1], AIMessage):
messages.append(AIMessage(content=""))
@@ -234,21 +227,13 @@ def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:

def _convert_one_message_to_text_deepseek(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
message_text = (
f"<|{message.role}|>{message.content}"
)
message_text = f"<|{message.role}|>{message.content}"
elif isinstance(message, HumanMessage):
message_text = (
f"<|User|>{message.content}"
)
message_text = f"<|User|>{message.content}"
elif isinstance(message, AIMessage):
message_text = (
f"<|Assistant|>{message.content}"
)
message_text = f"<|Assistant|>{message.content}"
elif isinstance(message, SystemMessage):
message_text = (
f"<|System|>{message.content}"
)
message_text = f"<|System|>{message.content}"
else:
raise ValueError(f"Got unknown type {message}")

@@ -291,30 +276,22 @@ def convert_messages_to_prompt_writer(messages: List[BaseMessage]) -> str:

def _convert_one_message_to_text_openai(message: BaseMessage) -> str:
if isinstance(message, SystemMessage):
message_text = (
f"<|start|>system<|message|>{message.content}<|end|>"
)
message_text = f"<|start|>system<|message|>{message.content}<|end|>"
elif isinstance(message, ChatMessage):
# developer role messages
message_text = (
f"<|start|>{message.role}<|message|>{message.content}<|end|>"
)
message_text = f"<|start|>{message.role}<|message|>{message.content}<|end|>"
elif isinstance(message, HumanMessage):
message_text = (
f"<|start|>user<|message|>{message.content}<|end|>"
)
message_text = f"<|start|>user<|message|>{message.content}<|end|>"
elif isinstance(message, AIMessage):
message_text = (
f"<|start|>assistant<|channel|>final<|message|>{message.content}<|end|>"
)
elif isinstance(message, ToolMessage):
# TODO: Tool messages in the OpenAI format should use "<|start|>{toolname} to=assistant<|message|>"
# Need to extract the tool name from the ToolMessage content or tool_call_id
# For now using generic "to=assistant" format as placeholder until we implement tool calling
# Will be resolved in follow-up PR with full tool support
message_text = (
f"<|start|>to=assistant<|channel|>commentary<|message|>{message.content}<|end|>"
)
# TODO: Tool messages in the OpenAI format should use "<|start|>{toolname} to=assistant<|message|>"
# Need to extract the tool name from the ToolMessage content or tool_call_id
# For now using generic "to=assistant" format as placeholder until we implement tool calling
# Will be resolved in follow-up PR with full tool support
message_text = f"<|start|>to=assistant<|channel|>commentary<|message|>{message.content}<|end|>"
else:
raise ValueError(f"Got unknown type {message}")
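Rendered end to end, these branches produce a Harmony-style transcript in which assistant turns are pinned to the final channel and tool output to commentary. A small sketch of the resulting string (formats copied from the branches above; the helper itself is private to langchain_aws):

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

templates = {
    SystemMessage: "<|start|>system<|message|>{}<|end|>",
    HumanMessage: "<|start|>user<|message|>{}<|end|>",
    AIMessage: "<|start|>assistant<|channel|>final<|message|>{}<|end|>",
}
conversation = [
    SystemMessage(content="You are terse."),
    HumanMessage(content="2+2?"),
    AIMessage(content="4"),
]
print("".join(templates[type(m)].format(m.content) for m in conversation))
# <|start|>system<|message|>You are terse.<|end|><|start|>user<|message|>2+2?<|end|><|start|>assistant<|channel|>final<|message|>4<|end|>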

@@ -373,7 +350,7 @@ def _format_data_content_block(block: dict) -> dict:
"type": "base64",
"media_type": block["mime_type"],
"data": block["data"],
}
},
}
else:
error_message = "Image data only supported through in-line base64 format."
@@ -440,9 +417,7 @@ def _format_anthropic_messages(
for i, message in enumerate(merged_messages):
if message.type == "system":
if system is not None:
raise ValueError(
"Received multiple non-consecutive system messages."
)
raise ValueError("Received multiple non-consecutive system messages.")
elif isinstance(message.content, str):
system = message.content
elif isinstance(message.content, list):
@@ -482,9 +457,9 @@ def _format_anthropic_messages(

if not isinstance(message.content, str):
# parse as dict
assert isinstance(
message.content, list
), "Anthropic message content must be str or list of dicts"
assert isinstance(message.content, list), (
"Anthropic message content must be str or list of dicts"
)

# populate content
content = []
@@ -514,19 +489,28 @@ def _format_anthropic_messages(
# Handle list content inside tool_result
processed_list = []
for list_item in content_item:
if isinstance(list_item, dict) and list_item.get("type") == "image_url":
if (
isinstance(list_item, dict)
and list_item.get("type") == "image_url"
):
# Process image in list
source = _format_image(list_item["image_url"]["url"])
processed_list.append({"type": "image", "source": source})
source = _format_image(
list_item["image_url"]["url"]
)
processed_list.append(
{"type": "image", "source": source}
)
else:
# Keep other items as is
processed_list.append(list_item)
# Add processed list to tool_result
tool_blocks.append({
"type": "tool_result",
"tool_use_id": item.get("tool_use_id"),
"content": processed_list
})
tool_blocks.append(
{
"type": "tool_result",
"tool_use_id": item.get("tool_use_id"),
"content": processed_list,
}
)
else:
# For other content types, keep as is
tool_blocks.append(item)
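The net effect: an OpenAI-style image_url entry inside a tool_result is rewritten into Anthropic's base64 source form, while sibling entries pass through untouched. The shapes, with illustrative values:

# Illustrative input item, as it might appear in a ToolMessage content list:
item = {
    "type": "tool_result",
    "tool_use_id": "toolu_123",
    "content": [
        {"type": "text", "text": "Here is the chart"},
        {"type": "image_url",
         "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
    ],
}
# Expected output block (image source fields produced by _format_image):
# {"type": "tool_result", "tool_use_id": "toolu_123", "content": [
#     {"type": "text", "text": "Here is the chart"},
#     {"type": "image", "source": {"type": "base64",
#      "media_type": "image/png", "data": "iVBORw0KGgo..."}}]}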
@@ -788,9 +772,9 @@ def set_beta_use_converse_api(cls, values: Dict) -> Any:
response = bedrock_client.get_inference_profile(
inferenceProfileIdentifier=model_id
)
if 'models' in response and len(response['models']) > 0:
model_arn = response['models'][0]['modelArn']
resolved_base_model = model_arn.split('/')[-1]
if "models" in response and len(response["models"]) > 0:
model_arn = response["models"][0]["modelArn"]
resolved_base_model = model_arn.split("/")[-1]
values["beta_use_converse_api"] = "nova" in resolved_base_model
return values
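Resolution here is plain string surgery on the profile's underlying model ARN. With an illustrative response (the models/modelArn shape matches the code above; the ARN value is made up):

response = {
    "models": [
        {"modelArn": "arn:aws:bedrock:us-east-1::foundation-model/amazon.nova-pro-v1:0"}
    ]
}
if "models" in response and len(response["models"]) > 0:
    resolved_base_model = response["models"][0]["modelArn"].split("/")[-1]
    print(resolved_base_model)            # amazon.nova-pro-v1:0
    print("nova" in resolved_base_model)  # True -> beta_use_converse_api = True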

@@ -877,7 +861,7 @@ def _stream(
added_model_name = False
# Track guardrails trace information for callback handling
guardrails_trace_info = None

for chunk in self._prepare_input_and_invoke_stream(
prompt=prompt,
system=system,
@@ -902,7 +886,7 @@
if services_trace.get("signal") and run_manager:
# Store trace info for potential callback
guardrails_trace_info = services_trace

usage_metadata = generation_info.pop("usage_metadata", None)
response_metadata = generation_info
if not added_model_name:
@@ -924,7 +908,7 @@
generation_chunk.text, chunk=generation_chunk
)
yield generation_chunk

# If guardrails intervened during streaming, notify the callback handler
if guardrails_trace_info and run_manager:
run_manager.on_llm_error(
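From the caller's side, a deferred guardrail intervention surfaces through the error callback once the stream completes. A usage sketch, assuming a standard langchain_core handler (the model id and wiring are illustrative, not the PR's test setup):

from langchain_core.callbacks import BaseCallbackHandler

class GuardrailLogger(BaseCallbackHandler):
    def on_llm_error(self, error: BaseException, **kwargs) -> None:
        # Invoked after streaming finishes if guardrails intervened mid-stream.
        print(f"guardrail signal: {error}")

# llm = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0",
#                   callbacks=[GuardrailLogger()])
# for chunk in llm.stream("..."):
#     print(chunk.content, end="")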
@@ -987,7 +971,9 @@ def _generate(
else:
system = self.system_prompt_with_tools
elif provider == "openai":
formatted_messages = ChatPromptAdapter.format_messages(provider, messages)
formatted_messages = ChatPromptAdapter.format_messages(
provider, messages
)
else:
prompt = ChatPromptAdapter.convert_messages_to_prompt(
provider=provider, messages=messages, model=self._get_base_model()
@@ -1116,13 +1102,10 @@ def bind_tools(

# Disallow forced tool use when thinking is enabled on specific Claude models
base_model = self._get_base_model()
if (
any(
x in base_model
for x in ("claude-3-7-", "claude-opus-4-", "claude-sonnet-4-")
)
and thinking_in_params(self.model_kwargs or {})
):
if any(
x in base_model
for x in ("claude-3-7-", "claude-opus-4-", "claude-sonnet-4-")
) and thinking_in_params(self.model_kwargs or {}):
forced = False
if isinstance(tool_choice, bool):
forced = bool(tool_choice)
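For context on the guard above: Anthropic rejects forced tool use (tool_choice "any" or a named tool) when extended thinking is enabled, so the affected Claude 3.7 / Claude 4 families are screened by model-id substring. A hedged reproduction of the predicate, assuming thinking_in_params reports thinking.type == "enabled":

def forced_tool_use_blocked(base_model: str, model_kwargs: dict) -> bool:
    # Mirrors the check above; thinking_in_params is approximated here.
    thinking = model_kwargs.get("thinking", {}).get("type") == "enabled"
    return thinking and any(
        x in base_model for x in ("claude-3-7-", "claude-opus-4-", "claude-sonnet-4-")
    )

print(forced_tool_use_blocked(
    "anthropic.claude-3-7-sonnet-20250219-v1:0",
    {"thinking": {"type": "enabled", "budget_tokens": 1024}},
))  # True -> a forced tool_choice is disallowed on this configuration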
6 changes: 3 additions & 3 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -467,9 +467,9 @@ class Joke(BaseModel):
additionalModelResponseFieldPaths.
"""

supports_tool_choice_values: Optional[
Sequence[Literal["auto", "any", "tool"]]
] = None
supports_tool_choice_values: Optional[Sequence[Literal["auto", "any", "tool"]]] = (
None
)
"""Which types of tool_choice values the model supports.

Inferred if not specified. Inferred as ('auto', 'any', 'tool') if a 'claude-3'
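When ChatBedrockConverse cannot infer tool-choice support from the model id, the field can be set explicitly. A usage sketch (the model id and the single-value override are illustrative):

from langchain_aws import ChatBedrockConverse

llm = ChatBedrockConverse(
    model="us.amazon.nova-pro-v1:0",        # illustrative model id
    supports_tool_choice_values=("auto",),  # e.g. model only honors auto
)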
27 changes: 14 additions & 13 deletions libs/aws/langchain_aws/chat_models/sagemaker_endpoint.py
@@ -1,4 +1,5 @@
"""Sagemaker Chat Model."""

import io
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional
@@ -136,16 +137,16 @@ class ChatSagemakerEndpoint(BaseChatModel):
EC2 instance, credentials from IMDS will be used.

client: boto3 client for Sagemaker Endpoint

endpoint_name: The name of the endpoint from the deployed Sagemaker model.

content_handler: Implementation for model specific ChatContentHandler
content_handler: Implementation for model specific ChatContentHandler


Example:
.. code-block:: python

from langchain_aws.chat_models.sagemaker_endpoint import (
    ChatSagemakerEndpoint,
)
endpoint_name = (
"my-endpoint-name"
@@ -161,7 +162,7 @@ class ChatSagemakerEndpoint(BaseChatModel):
region_name=region_name,
credentials_profile_name=credentials_profile_name
)

# Usage with Inference Component
se = ChatSagemakerEndpoint(
endpoint_name=endpoint_name,
@@ -190,13 +191,13 @@ class ChatSagemakerEndpoint(BaseChatModel):
Must be unique within an AWS Region."""

inference_component_name: Optional[str] = None
"""Optional name of the inference component to invoke
"""Optional name of the inference component to invoke
if specified with endpoint name."""

region_name: Optional[str] = ""
"""The aws region, e.g., `us-west-2`.
"""The aws region, e.g., `us-west-2`.

Falls back to AWS_REGION or AWS_DEFAULT_REGION env variable or region specified in
Falls back to AWS_REGION or AWS_DEFAULT_REGION env variable or region specified in
~/.aws/config in case it is not provided here.
"""

@@ -205,14 +206,14 @@

Profile should either have access keys or role information specified.
If not specified, the default credential profile or, if on an EC2 instance,
credentials from IMDS will be used.
credentials from IMDS will be used.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
"""

aws_access_key_id: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_ACCESS_KEY_ID", default=None)
)
"""AWS access key id.
"""AWS access key id.

If provided, aws_secret_access_key must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
@@ -225,7 +226,7 @@
aws_secret_access_key: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SECRET_ACCESS_KEY", default=None)
)
"""AWS secret_access_key.
"""AWS secret_access_key.

If provided, aws_access_key_id must also be provided.
If not specified, the default credential profile or, if on an EC2 instance,
@@ -238,9 +239,9 @@
aws_session_token: Optional[SecretStr] = Field(
default_factory=secret_from_env("AWS_SESSION_TOKEN", default=None)
)
"""AWS session token.
"""AWS session token.

If provided, aws_access_key_id and aws_secret_access_key must
If provided, aws_access_key_id and aws_secret_access_key must
also be provided. Not required unless using temporary credentials.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

@@ -275,7 +276,7 @@ class ContentHandler(ChatContentHandler):
def transform_input(self, prompt: List[Dict[str, Any]], model_kwargs: Dict) -> bytes:
input_str = json.dumps({"prompt": prompt, **model_kwargs})
return input_str.encode('utf-8')

def transform_output(self, output: bytes) -> BaseMessage:
response_json = json.loads(output.read().decode("utf-8"))
return AIMessage(content=response_json[0]["generated_text"])