Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions libs/aws/langchain_aws/chat_models/bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,70 @@ def create_cache_point(cls, cache_type: str = "default") -> Dict[str, Any]:
"""
return {"cachePoint": {"type": cache_type}}

@classmethod
def create_document(
cls,
name: str,
source: dict[str, Any],
format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"],
context: Optional[str] = None,
enable_citations: Optional[bool] = False,

) -> Dict[str, Any]:
"""Create a document configuration for Bedrock.
Args:
name: The name of the document.
source: The source of the document.
format: The format of the document, or its extension.
context: Info for the model to understand the document for citations.
enable_citations: Whether to enable the Citations API for the document.
Returns:
Dictionary containing a properly formatted to add to message content."""
if re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name):
raise ValueError(
"Name must be only alphanumeric characters,"
" whitespace characters (no more than one in a row),"
" hyphens, parentheses, or square brackets."
)

valid_source_types = ["bytes", "content", "s3Location", "text"]
if (
len(source.keys()) > 1
or list(source.keys())[0] not in valid_source_types
):
raise ValueError(
f"The key for source can only be one of the following: {valid_source_types}"
)

if source.get("bytes") and not isinstance(source.get("bytes"), bytes):
raise ValueError(f"Document source with type bytes must be bytes type.")

if source.get("text") and not isinstance(source.get("text"), str):
raise ValueError("Document source with type text must be str type.")

if source.get("s3Location") and not isinstance(
source.get("s3Location").get("uri"), str
):
raise ValueError(
"Document source with type s3Location"
" must have a dictionary with a valid s3 uri as a dict."
)

if source.get("content") and not isinstance(source.get("content", list)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing content source currently fails because isinstance is missing the second argument for type here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops good catch thanks!

raise ValueError(
"Document source with type content must have a list of document content blocks."
)

document = {"name": name, "source": source, "format": format}

if context:
document["context"] = context

if enable_citations:
document["citations"] = {"enabled": True}

return {"document": document}

@model_validator(mode="before")
@classmethod
def build_extra(cls, values: dict[str, Any]) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,3 +648,51 @@ def test_bedrock_pdf_inputs() -> None:
]
)
_ = model.invoke([message])


def test_bedrock_document_usage() -> None:
model = ChatBedrockConverse(
model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2"
)

# Test bytes source type
url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
pdf_bytes = httpx.get(url).content
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
ChatBedrockConverse.create_document(
"PDFDoc", source={"bytes": pdf_bytes}, format="pdf"
),
]
)

_ = model.invoke([message])

# Test text source type
text = "I am a text document."
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
ChatBedrockConverse.create_document(
"TextDoc", source={"text": text}, format="txt"
),
]
)
_ = model.invoke([message])

# Test content source type
split_text = [
{"text": "I am the first part of a document."},
{"text": "I am the second part."},
{"text": "I am not sure how I got here."},
]
message = HumanMessage(
[
{"type": "text", "text": "Summarize this document:"},
ChatBedrockConverse.create_document(
"TextDoc", source={"content": split_text}, format="txt"
),
]
)
_ = model.invoke([message])
16 changes: 16 additions & 0 deletions libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,6 +1515,22 @@ def test_create_cache_point() -> None:
assert cache_point["cachePoint"]["type"] == "default"


def test_create_document() -> None:
"""Test creating a document."""
document = ChatBedrockConverse.create_document(
name="MyDoc", source={"text": "Cite me"}, format="txt", enable_citations=True
)
expected_doc = {
"document": {
"name": "MyDoc",
"source": {"text": "Cite me"},
"format": "txt",
"citations": {"enabled": True},
}
}
assert document == expected_doc


def test_anthropic_tool_with_cache_point() -> None:
"""Test convert_to_anthropic_tool with cache point"""
# Test with cache point
Expand Down