diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
index 04c9bfd5..72bd0387 100644
--- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py
+++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py
@@ -514,6 +514,89 @@ def create_cache_point(cls, cache_type: str = "default") -> Dict[str, Any]:
         """
         return {"cachePoint": {"type": cache_type}}
 
+    @classmethod
+    def create_document(
+        cls,
+        name: str,
+        source: dict[str, Any],
+        format: Literal[
+            "pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"
+        ],
+        context: Optional[str] = None,
+        enable_citations: Optional[bool] = False,
+    ) -> Dict[str, Any]:
+        """Create a document content block for the Bedrock Converse API.
+
+        Args:
+            name: The name of the document. Only alphanumeric characters, single
+                spaces, hyphens, parentheses, and square brackets are accepted.
+            source: The source of the document: a dictionary with exactly one
+                of the keys ``bytes``, ``content``, ``s3Location``, or ``text``.
+            format: The format of the document, or its extension.
+            context: Info for the model to understand the document for citations.
+            enable_citations: Whether to enable the Citations API for the document.
+
+        Returns:
+            Dictionary containing a properly formatted document block to add to
+            message content.
+
+        Raises:
+            ValueError: If ``name`` contains unsupported characters, or if
+                ``source`` does not hold exactly one supported key whose value
+                has the expected type.
+        """
+        if re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name):
+            raise ValueError(
+                "Name must be only alphanumeric characters,"
+                " whitespace characters (no more than one in a row),"
+                " hyphens, parentheses, or square brackets."
+            )
+
+        valid_source_types = ["bytes", "content", "s3Location", "text"]
+        # Require exactly one key: an empty source must fail validation here
+        # instead of raising IndexError on the key lookup.
+        if len(source) != 1 or next(iter(source)) not in valid_source_types:
+            raise ValueError(
+                "The key for source can only be one of the following: "
+                f"{valid_source_types}"
+            )
+
+        # Check by key presence rather than truthiness so that falsy values of
+        # the wrong type (e.g. None) are rejected instead of skipped.
+        if "bytes" in source and not isinstance(source["bytes"], bytes):
+            raise ValueError("Document source with type bytes must be bytes type.")
+
+        if "text" in source and not isinstance(source["text"], str):
+            raise ValueError("Document source with type text must be str type.")
+
+        if "s3Location" in source:
+            s3_location = source["s3Location"]
+            # Guard the dict check first so a non-dict value raises ValueError
+            # rather than AttributeError on ``.get``.
+            if not isinstance(s3_location, dict) or not isinstance(
+                s3_location.get("uri"), str
+            ):
+                raise ValueError(
+                    "Document source with type s3Location"
+                    " must be a dictionary with a valid s3 uri as a str."
+                )
+
+        if "content" in source and not isinstance(source["content"], list):
+            raise ValueError(
+                "Document source with type content must have a list of"
+                " document content blocks."
+            )
+
+        document: Dict[str, Any] = {"name": name, "source": source, "format": format}
+
+        if context:
+            document["context"] = context
+
+        if enable_citations:
+            document["citations"] = {"enabled": True}
+
+        return {"document": document}
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: dict[str, Any]) -> Any:
diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py
index b0c6c5ae..e41a5fd6 100644
--- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py
+++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py
@@ -648,3 +648,51 @@ def test_bedrock_pdf_inputs() -> None:
         ]
     )
     _ = model.invoke([message])
+
+
+def test_bedrock_document_usage() -> None:
+    model = ChatBedrockConverse(
+        model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2"
+    )
+
+    # Test bytes source type
+    url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
+    pdf_bytes = httpx.get(url).content
+    message = HumanMessage(
+        [
+            {"type": "text", "text": "Summarize this document:"},
+            ChatBedrockConverse.create_document(
+                "PDFDoc", source={"bytes": pdf_bytes}, format="pdf"
+            ),
+        ]
+    )
+
+    _ = model.invoke([message])
+
+    # Test text source type
+    text = "I am a text document."
+    message = HumanMessage(
+        [
+            {"type": "text", "text": "Summarize this document:"},
+            ChatBedrockConverse.create_document(
+                "TextDoc", source={"text": text}, format="txt"
+            ),
+        ]
+    )
+    _ = model.invoke([message])
+
+    # Test content source type
+    split_text = [
+        {"text": "I am the first part of a document."},
+        {"text": "I am the second part."},
+        {"text": "I am not sure how I got here."},
+    ]
+    message = HumanMessage(
+        [
+            {"type": "text", "text": "Summarize this document:"},
+            ChatBedrockConverse.create_document(
+                "TextDoc", source={"content": split_text}, format="txt"
+            ),
+        ]
+    )
+    _ = model.invoke([message])
diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
index ce739ae3..a1cb84a7 100644
--- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
+++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py
@@ -1515,6 +1515,35 @@ def test_create_cache_point() -> None:
     assert cache_point["cachePoint"]["type"] == "default"
 
 
+def test_create_document() -> None:
+    """Test creating a document."""
+    document = ChatBedrockConverse.create_document(
+        name="MyDoc", source={"text": "Cite me"}, format="txt", enable_citations=True
+    )
+    expected_doc = {
+        "document": {
+            "name": "MyDoc",
+            "source": {"text": "Cite me"},
+            "format": "txt",
+            "citations": {"enabled": True},
+        }
+    }
+    assert document == expected_doc
+
+    # A ``content`` source must pass validation and be forwarded unchanged.
+    content_blocks = [{"text": "Cite me"}, {"text": "Cite me too"}]
+    content_document = ChatBedrockConverse.create_document(
+        name="MyDoc", source={"content": content_blocks}, format="txt"
+    )
+    assert content_document == {
+        "document": {
+            "name": "MyDoc",
+            "source": {"content": content_blocks},
+            "format": "txt",
+        }
+    }
+
+
 def test_anthropic_tool_with_cache_point() -> None:
     """Test convert_to_anthropic_tool with cache point"""
     # Test with cache point