From 4f4c25c18e72da3ed3be8a2432720b47d305c69e Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Sun, 31 Aug 2025 20:21:29 -0500 Subject: [PATCH 01/13] feat: add create_document class method to add easily add documents to message content --- .../chat_models/bedrock_converse.py | 131 ++++++++++++++---- 1 file changed, 106 insertions(+), 25 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 96b7be83..bf162cb1 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -65,7 +65,7 @@ MIME_TO_FORMAT = { # Image formats "image/png": "png", - "image/jpeg": "jpeg", + "image/jpeg": "jpeg", "image/gif": "gif", "image/webp": "webp", # File formats @@ -465,9 +465,9 @@ class Joke(BaseModel): additionalModelResponseFieldPaths. """ - supports_tool_choice_values: Optional[ - Sequence[Literal["auto", "any", "tool"]] - ] = None + supports_tool_choice_values: Optional[Sequence[Literal["auto", "any", "tool"]]] = ( + None + ) """Which types of tool_choice values the model supports. Inferred if not specified. Inferred as ('auto', 'any', 'tool') if a 'claude-3' @@ -512,6 +512,74 @@ def create_cache_point(cls, cache_type: str = "default") -> Dict[str, Any]: """ return {"cachePoint": {"type": cache_type}} + @classmethod + def create_document( + cls, + name: str, + source: dict[str, Any], + context: Optional[str] = None, + enable_citations: Optional[bool] = False, + format: Optional[ + Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] + ] = None, + ) -> Dict[str, Any]: + """Create a document configuration for Bedrock. + Args: + name: The name of the document. + source: The source of the document. + context: Info for the model to understand the document for citations. + format: The format of the document, or its extension. + Returns: + Dictionary containing a properly formatted to add to message content.""" + if re.match(r"[^\w\[\]\(\)-]|[\s]{2,}", name): + raise ValueError( + "Name must be only alphanumeric characters," + " whitespace characters (no more than one in a row)," + " hyphens, parantheses, or square brackets." + ) + + valid_source_types = ["bytes", "content", "s3Location", "text"] + if ( + len(source.keys()) > 1 + or source.keys() + or source.keys()[0] not in valid_source_types + ): + raise ValueError( + f"The key for source can only be one of the following: {valid_source_types}" + ) + + if source.get("bytes") and not isinstance(source.get("bytes"), bytes): + raise ValueError(f"Document source with type bytes must be bytes type.") + + if source.get("text") and not isinstance(source.get("text"), str): + raise ValueError("Document source with type text must be str type.") + + if source.get("s3Location") and not isinstance( + source.get("s3Location").get("uri"), str + ): + raise ValueError( + "Document source with type s3Location" + " must have a dictionary with a valid s3 uri as a dict." + ) + + if source.get("content") and not isinstance(source.get("content", list)): + raise ValueError( + "Document source with type content must have a list of document content blocks." + ) + + document = {"name": name, "source": source} + + if context: + document["context"] = context + + if format: + document["format"] = format + + if enable_citations: + document["citation"] = {"enabled": True} + + return {"document": document} + @model_validator(mode="before") @classmethod def build_extra(cls, values: dict[str, Any]) -> Any: @@ -533,9 +601,11 @@ def build_extra(cls, values: dict[str, Any]) -> Any: return values @classmethod - def _get_streaming_support(cls, provider: str, model_id_lower: str) -> Union[bool, str]: + def _get_streaming_support( + cls, provider: str, model_id_lower: str + ) -> Union[bool, str]: """Determine streaming support for a given provider and model. - + Returns: True: Full streaming support "no_tools": Streaming supported but not with tools @@ -612,7 +682,7 @@ def _get_streaming_support(cls, provider: str, model_id_lower: str) -> Union[boo @classmethod def set_disable_streaming(cls, values: Dict) -> Any: model_id = values.get("model_id", values.get("model")) - + # Extract provider from the model_id # (e.g., "amazon", "anthropic", "ai21", "meta", "mistral") if "provider" not in values or values["provider"] == "": @@ -652,8 +722,8 @@ def set_disable_streaming(cls, values: Dict) -> Any: @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that AWS credentials to and python package exists in environment.""" - - # Create bedrock client for control plane API call + + # Create bedrock client for control plane API call if self.bedrock_client is None: self.bedrock_client = create_aws_client( region_name=self.region_name, @@ -665,7 +735,7 @@ def validate_environment(self) -> Self: config=self.config, service_name="bedrock", ) - + # Handle streaming configuration for application inference profiles if "application-inference-profile" in self.model_id: self._configure_streaming_for_resolved_model() @@ -712,27 +782,30 @@ def validate_environment(self) -> Self: "Provide a guardrail via `guardrail_config` or " "disable `guard_last_turn_only`." ) - + return self def _get_base_model(self) -> str: # identify the base model id used in the application inference profile (AIP) # Format: arn:aws:bedrock:us-east-1::application-inference-profile/ - if self.base_model_id is None and 'application-inference-profile' in self.model_id: + if ( + self.base_model_id is None + and "application-inference-profile" in self.model_id + ): response = self.bedrock_client.get_inference_profile( inferenceProfileIdentifier=self.model_id ) - if 'models' in response and len(response['models']) > 0: - model_arn = response['models'][0]['modelArn'] + if "models" in response and len(response["models"]) > 0: + model_arn = response["models"][0]["modelArn"] # Format: arn:aws:bedrock:region::foundation-model/provider.model-name - self.base_model_id = model_arn.split('/')[-1] + self.base_model_id = model_arn.split("/")[-1] return self.base_model_id if self.base_model_id else self.model_id - + def _configure_streaming_for_resolved_model(self) -> None: """Configure streaming support after resolving the base model for application inference profiles.""" base_model = self._get_base_model() model_id_lower = base_model.lower() - + streaming_support = self._get_streaming_support(self.provider, model_id_lower) # Set the disable_streaming flag accordingly @@ -1194,7 +1267,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]: ) # always keep block inside a list to preserve merging compatibility content = [block] - + return AIMessageChunk(content=content, tool_call_chunks=tool_call_chunks) elif "contentBlockDelta" in event: block = { @@ -1213,7 +1286,7 @@ def _parse_stream_event(event: Dict[str, Any]) -> Optional[BaseMessageChunk]: ) # always keep block inside a list to preserve merging compatibility content = [block] - + return AIMessageChunk(content=content, tool_call_chunks=tool_call_chunks) elif "contentBlockStop" in event: # TODO: needed? @@ -1244,13 +1317,13 @@ def _mime_type_to_format(mime_type: str) -> str: if mime_type in MIME_TO_FORMAT: return MIME_TO_FORMAT[mime_type] - + # Fallback to original method of splitting on "/" for simple cases all_formats = set(MIME_TO_FORMAT.values()) format_part = mime_type.split("/")[1] if format_part in all_formats: return format_part - + raise ValueError( f"Unsupported MIME type: {mime_type}. Please refer to the Bedrock Converse API documentation for supported formats." ) @@ -1327,7 +1400,9 @@ def _lc_content_to_bedrock( ): bedrock_content.append(_format_data_content_block(block)) elif block["type"] == "text": - if not block["text"] or (isinstance(block["text"], str) and block["text"].isspace()): + if not block["text"] or ( + isinstance(block["text"], str) and block["text"].isspace() + ): bedrock_content.append({"text": "."}) else: bedrock_content.append({"text": block["text"]}) @@ -1339,7 +1414,9 @@ def _lc_content_to_bedrock( bedrock_content.append( { "image": { - "format": _mime_type_to_format(block["source"]["mediaType"]), + "format": _mime_type_to_format( + block["source"]["mediaType"] + ), "source": { "bytes": _b64str_to_bytes(block["source"]["data"]) }, @@ -1360,7 +1437,9 @@ def _lc_content_to_bedrock( bedrock_content.append( { "video": { - "format": _mime_type_to_format(block["source"]["mediaType"]), + "format": _mime_type_to_format( + block["source"]["mediaType"] + ), "source": { "bytes": _b64str_to_bytes(block["source"]["data"]) }, @@ -1371,7 +1450,9 @@ def _lc_content_to_bedrock( bedrock_content.append( { "video": { - "format": _mime_type_to_format(block["source"]["mediaType"]), + "format": _mime_type_to_format( + block["source"]["mediaType"] + ), "source": {"s3Location": block["source"]["data"]}, } } From 13fe75988a85b4f75baf68a40cb4984a193638ec Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Sun, 31 Aug 2025 20:22:15 -0500 Subject: [PATCH 02/13] test: add testing for create_document with simple document --- .../chat_models/test_bedrock_converse.py | 246 ++++++++++-------- 1 file changed, 133 insertions(+), 113 deletions(-) diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index c3b1d0ea..a5f0d9e4 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -487,8 +487,7 @@ def test__snake_to_camel_keys() -> None: assert _snake_to_camel_keys(_SNAKE_DICT) == _CAMEL_DICT -def test__format_openai_image_url() -> None: - ... +def test__format_openai_image_url() -> None: ... def test_standard_tracing_params() -> None: @@ -1113,7 +1112,7 @@ def test__lc_content_to_bedrock_mime_types() -> None: video_data = base64.b64encode(b"video_test_data").decode("utf-8") image_data = base64.b64encode(b"image_test_data").decode("utf-8") file_data = base64.b64encode(b"file_test_data").decode("utf-8") - + # Create content with one of each type content: List[Union[str, Dict[str, Any]]] = [ { @@ -1140,31 +1139,25 @@ def test__lc_content_to_bedrock_mime_types() -> None: "name": "test_document.pdf", }, ] - + expected_content = [ { "video": { "format": "mp4", - "source": { - "bytes": base64.b64decode(video_data.encode("utf-8")) - }, + "source": {"bytes": base64.b64decode(video_data.encode("utf-8"))}, } }, { "image": { "format": "jpeg", - "source": { - "bytes": base64.b64decode(image_data.encode("utf-8")) - }, + "source": {"bytes": base64.b64decode(image_data.encode("utf-8"))}, } }, { "document": { "format": "pdf", "name": "test_document.pdf", - "source": { - "bytes": base64.b64decode(file_data.encode("utf-8")) - }, + "source": {"bytes": base64.b64decode(file_data.encode("utf-8"))}, } }, ] @@ -1175,52 +1168,56 @@ def test__lc_content_to_bedrock_mime_types() -> None: def test__lc_content_to_bedrock_mime_types_invalid() -> None: with pytest.raises(ValueError, match="Invalid MIME type format"): - _lc_content_to_bedrock([ - { - "type": "image", - "source": { - "type": "base64", - "mediaType": "invalidmimetype", - "data": base64.b64encode(b"test_data").decode("utf-8"), - }, - } - ]) - + _lc_content_to_bedrock( + [ + { + "type": "image", + "source": { + "type": "base64", + "mediaType": "invalidmimetype", + "data": base64.b64encode(b"test_data").decode("utf-8"), + }, + } + ] + ) + with pytest.raises(ValueError, match="Unsupported MIME type"): - _lc_content_to_bedrock([ - { - "type": "file", - "sourceType": "base64", - "mimeType": "application/unknown-format", - "data": base64.b64encode(b"test_data").decode("utf-8"), - "name": "test_document.xyz", - } - ]) + _lc_content_to_bedrock( + [ + { + "type": "file", + "sourceType": "base64", + "mimeType": "application/unknown-format", + "data": base64.b64encode(b"test_data").decode("utf-8"), + "name": "test_document.xyz", + } + ] + ) def test__lc_content_to_bedrock_empty_content() -> None: content: List[Union[str, Dict[str, Any]]] = [] - + bedrock_content = _lc_content_to_bedrock(content) - + assert len(bedrock_content) > 0 assert bedrock_content[0]["text"] == "." def test__lc_content_to_bedrock_whitespace_only_content() -> None: content = " \n \t " - + bedrock_content = _lc_content_to_bedrock(content) - + assert len(bedrock_content) > 0 assert bedrock_content[0]["text"] == "." def test__lc_content_to_bedrock_empty_string_content() -> None: content = "" - + bedrock_content = _lc_content_to_bedrock(content) - + assert len(bedrock_content) > 0 assert bedrock_content[0]["text"] == "." @@ -1229,9 +1226,9 @@ def test__lc_content_to_bedrock_mixed_empty_content() -> None: content: List[Union[str, Dict[str, Any]]] = [ {"type": "text", "text": ""}, {"type": "text", "text": " "}, - {"type": "text", "text": ""} + {"type": "text", "text": ""}, ] - + bedrock_content = _lc_content_to_bedrock(content) assert len(bedrock_content) > 0 @@ -1239,23 +1236,19 @@ def test__lc_content_to_bedrock_mixed_empty_content() -> None: def test__lc_content_to_bedrock_empty_text_block() -> None: - content: List[Union[str, Dict[str, Any]]] = [ - {"type": "text", "text": ""} - ] - + content: List[Union[str, Dict[str, Any]]] = [{"type": "text", "text": ""}] + bedrock_content = _lc_content_to_bedrock(content) - + assert len(bedrock_content) > 0 assert bedrock_content[0]["text"] == "." def test__lc_content_to_bedrock_whitespace_text_block() -> None: - content: List[Union[str, Dict[str, Any]]] = [ - {"type": "text", "text": " \n "} - ] - + content: List[Union[str, Dict[str, Any]]] = [{"type": "text", "text": " \n "}] + bedrock_content = _lc_content_to_bedrock(content) - + assert len(bedrock_content) > 0 assert bedrock_content[0]["text"] == "." @@ -1264,9 +1257,9 @@ def test__lc_content_to_bedrock_mixed_valid_and_empty_content() -> None: content: List[Union[str, Dict[str, Any]]] = [ {"type": "text", "text": "Valid text"}, {"type": "text", "text": ""}, - {"type": "text", "text": " "} + {"type": "text", "text": " "}, ] - + bedrock_content = _lc_content_to_bedrock(content) assert len(bedrock_content) == 3 @@ -1284,21 +1277,21 @@ def test__lc_content_to_bedrock_mixed_types_with_empty_content() -> None: "input": {"arg1": "val1"}, "name": "tool1", }, - {"type": "text", "text": " "} + {"type": "text", "text": " "}, ] expected = [ - {'text': 'Valid text'}, + {"text": "Valid text"}, { - 'toolUse': { - 'toolUseId': 'tool_call1', - 'input': {'arg1': 'val1'}, - 'name': 'tool1' + "toolUse": { + "toolUseId": "tool_call1", + "input": {"arg1": "val1"}, + "name": "tool1", } }, - {'text': '.'} + {"text": "."}, ] - + bedrock_content = _lc_content_to_bedrock(content) assert len(bedrock_content) == 3 @@ -1431,6 +1424,21 @@ def test_create_cache_point() -> None: assert cache_point["cachePoint"]["type"] == "default" +def test_create_document() -> None: + """Test creating a document.""" + document = ChatBedrockConverse.create_document( + name="MyDoc", source={"text": "Cite me"}, enable_citations=True + ) + expected_doc = { + "document": { + "name": "MyDoc", + "source": {"text": "Cite me"}, + "citations": {"enabled": True}, + } + } + assert document == expected_doc + + def test_anthropic_tool_with_cache_point() -> None: """Test convert_to_anthropic_tool with cache point""" # Test with cache point @@ -1507,9 +1515,9 @@ def test_model_kwargs() -> None: assert llm.temperature is None -def _create_mock_llm_guard_last_turn_only() -> ( - Tuple[ChatBedrockConverse, mock.MagicMock] -): +def _create_mock_llm_guard_last_turn_only() -> Tuple[ + ChatBedrockConverse, mock.MagicMock +]: """Utility to create an LLM with guard_last_turn_only=True and a mocked client.""" mocked_client = mock.MagicMock() llm = ChatBedrockConverse( @@ -1582,61 +1590,63 @@ def test_stream_guard_last_turn_only() -> None: "guardContent": {"text": {"text": "How are you?"}} } + @mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client") def test_bedrock_client_creation(mock_create_client: mock.Mock) -> None: """Test that bedrock_client is created during validation.""" mock_bedrock_client = mock.Mock() mock_runtime_client = mock.Mock() - + def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: if service_name == "bedrock": return mock_bedrock_client elif service_name == "bedrock-runtime": return mock_runtime_client return mock.Mock() - + mock_create_client.side_effect = side_effect - + chat_model = ChatBedrockConverse( - model="anthropic.claude-3-sonnet-20240229-v1:0", - region_name="us-west-2" + model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="us-west-2" ) - + assert chat_model.bedrock_client == mock_bedrock_client assert chat_model.client == mock_runtime_client assert mock_create_client.call_count == 2 @mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client") -def test_get_base_model_with_application_inference_profile(mock_create_client: mock.Mock) -> None: +def test_get_base_model_with_application_inference_profile( + mock_create_client: mock.Mock, +) -> None: """Test _get_base_model method with application inference profile.""" mock_bedrock_client = mock.Mock() mock_runtime_client = mock.Mock() - + # Mock the get_inference_profile response mock_bedrock_client.get_inference_profile.return_value = { - 'models': [ + "models": [ { - 'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0' + "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0" } ] } - + def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: if service_name == "bedrock": return mock_bedrock_client elif service_name == "bedrock-runtime": return mock_runtime_client return mock.Mock() - + mock_create_client.side_effect = side_effect - + chat_model = ChatBedrockConverse( model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile", region_name="us-west-2", - provider="anthropic" + provider="anthropic", ) - + base_model = chat_model._get_base_model() assert base_model == "anthropic.claude-3-sonnet-20240229-v1:0" mock_bedrock_client.get_inference_profile.assert_called_once_with( @@ -1645,26 +1655,28 @@ def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: @mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client") -def test_get_base_model_without_application_inference_profile(mock_create_client: mock.Mock) -> None: +def test_get_base_model_without_application_inference_profile( + mock_create_client: mock.Mock, +) -> None: """Test _get_base_model method without application inference profile.""" mock_bedrock_client = mock.Mock() mock_runtime_client = mock.Mock() - + def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: if service_name == "bedrock": return mock_bedrock_client elif service_name == "bedrock-runtime": return mock_runtime_client return mock.Mock() - + mock_create_client.side_effect = side_effect - + chat_model = ChatBedrockConverse( model="anthropic.claude-3-sonnet-20240229-v1:0", region_name="us-west-2", - provider="anthropic" + provider="anthropic", ) - + base_model = chat_model._get_base_model() assert base_model == "anthropic.claude-3-sonnet-20240229-v1:0" mock_bedrock_client.get_inference_profile.assert_not_called() @@ -1675,104 +1687,112 @@ def test_configure_streaming_for_resolved_model(mock_create_client: mock.Mock) - """Test _configure_streaming_for_resolved_model method.""" mock_bedrock_client = mock.Mock() mock_runtime_client = mock.Mock() - + # Mock the get_inference_profile response for a model with full streaming support mock_bedrock_client.get_inference_profile.return_value = { - 'models': [ + "models": [ { - 'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0' + "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0" } ] } - + def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: if service_name == "bedrock": return mock_bedrock_client elif service_name == "bedrock-runtime": return mock_runtime_client return mock.Mock() - + mock_create_client.side_effect = side_effect - + chat_model = ChatBedrockConverse( model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile", region_name="us-west-2", - provider="anthropic" + provider="anthropic", ) - + # The streaming should be configured based on the resolved model assert chat_model.disable_streaming is False @mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client") -def test_configure_streaming_for_resolved_model_no_tools(mock_create_client: mock.Mock) -> None: +def test_configure_streaming_for_resolved_model_no_tools( + mock_create_client: mock.Mock, +) -> None: """Test _configure_streaming_for_resolved_model method with no-tools streaming.""" mock_bedrock_client = mock.Mock() mock_runtime_client = mock.Mock() - + # Mock the get_inference_profile response for a model with no-tools streaming support mock_bedrock_client.get_inference_profile.return_value = { - 'models': [ + "models": [ { - 'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1' + "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1" } ] } - + def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: if service_name == "bedrock": return mock_bedrock_client elif service_name == "bedrock-runtime": return mock_runtime_client return mock.Mock() - + mock_create_client.side_effect = side_effect - + chat_model = ChatBedrockConverse( model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile", region_name="us-west-2", - provider="amazon" + provider="amazon", ) - + # The streaming should be configured as "tool_calling" for no-tools models assert chat_model.disable_streaming == "tool_calling" @mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client") -def test_configure_streaming_for_resolved_model_no_streaming(mock_create_client: mock.Mock) -> None: +def test_configure_streaming_for_resolved_model_no_streaming( + mock_create_client: mock.Mock, +) -> None: """Test _configure_streaming_for_resolved_model method with no streaming support.""" mock_bedrock_client = mock.Mock() mock_runtime_client = mock.Mock() - + # Mock the get_inference_profile response for a model with no streaming support mock_bedrock_client.get_inference_profile.return_value = { - 'models': [ + "models": [ { - 'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/stability.stable-image-core-v1:0' + "modelArn": "arn:aws:bedrock:us-east-1::foundation-model/stability.stable-image-core-v1:0" } ] } - + def side_effect(service_name: str, **kwargs: Any) -> mock.Mock: if service_name == "bedrock": return mock_bedrock_client elif service_name == "bedrock-runtime": return mock_runtime_client return mock.Mock() - + mock_create_client.side_effect = side_effect - + chat_model = ChatBedrockConverse( model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile", region_name="us-west-2", - provider="stability" + provider="stability", ) - + # The streaming should be disabled for models with no streaming support assert chat_model.disable_streaming is True - + def test_nova_provider_extraction() -> None: """Test that provider is correctly extracted from Nova model ID when not provided.""" - model = ChatBedrockConverse(client=mock.MagicMock(), model="us.amazon.nova-pro-v1:0", region_name="us-west-2") + model = ChatBedrockConverse( + client=mock.MagicMock(), + model="us.amazon.nova-pro-v1:0", + region_name="us-west-2", + ) assert model.provider == "amazon" From 71a21ff7a8d819a36b450f768292205edc2e829f Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Mon, 1 Sep 2025 11:33:05 -0500 Subject: [PATCH 03/13] fix: fixed bugs in code --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index bf162cb1..e39485de 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -541,8 +541,7 @@ def create_document( valid_source_types = ["bytes", "content", "s3Location", "text"] if ( len(source.keys()) > 1 - or source.keys() - or source.keys()[0] not in valid_source_types + or list(source.keys())[0] not in valid_source_types ): raise ValueError( f"The key for source can only be one of the following: {valid_source_types}" @@ -576,7 +575,7 @@ def create_document( document["format"] = format if enable_citations: - document["citation"] = {"enabled": True} + document["citations"] = {"enabled": True} return {"document": document} From b4d82382b51087c5ab75a487f4b12e13ea2cedd9 Mon Sep 17 00:00:00 2001 From: hatMatch <116861403+hatMatch@users.noreply.github.com> Date: Thu, 4 Sep 2025 21:02:38 -0500 Subject: [PATCH 04/13] Update libs/aws/langchain_aws/chat_models/bedrock_converse.py Co-authored-by: Michael Chin --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index e39485de..34c89fea 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -531,7 +531,7 @@ def create_document( format: The format of the document, or its extension. Returns: Dictionary containing a properly formatted to add to message content.""" - if re.match(r"[^\w\[\]\(\)-]|[\s]{2,}", name): + if not re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name): raise ValueError( "Name must be only alphanumeric characters," " whitespace characters (no more than one in a row)," From 58c7ecba0b9dcb045f75419f00189313c9cf3f8a Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 11:15:36 -0500 Subject: [PATCH 05/13] fix: changed format to be required arg in create_document --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 34c89fea..79dfc697 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -517,11 +517,10 @@ def create_document( cls, name: str, source: dict[str, Any], + format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"], context: Optional[str] = None, enable_citations: Optional[bool] = False, - format: Optional[ - Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] - ] = None, + ) -> Dict[str, Any]: """Create a document configuration for Bedrock. Args: From 4812a051d75a468093eafb582eaa3ac73a40aa35 Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 11:16:39 -0500 Subject: [PATCH 06/13] fix: added enable_citations in create_document docstring for completeness --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 79dfc697..9c5e8b77 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -526,8 +526,9 @@ def create_document( Args: name: The name of the document. source: The source of the document. - context: Info for the model to understand the document for citations. format: The format of the document, or its extension. + context: Info for the model to understand the document for citations. + enable_citations: Whether to enable the Citations API for the document. Returns: Dictionary containing a properly formatted to add to message content.""" if not re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name): From 003d288514ae4dc2e2cd9854621392b44a2d2b78 Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 11:17:43 -0500 Subject: [PATCH 07/13] fix: cleaned up format code as it is now required --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 9c5e8b77..5ab980da 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -566,14 +566,11 @@ def create_document( "Document source with type content must have a list of document content blocks." ) - document = {"name": name, "source": source} + document = {"name": name, "source": source, "format": format} if context: document["context"] = context - if format: - document["format"] = format - if enable_citations: document["citations"] = {"enabled": True} From 114bc30e983e43e9aa41ae9968250f32a23fe61a Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:18:04 -0500 Subject: [PATCH 08/13] fix: updated test_create_document to have a format as it is required --- libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py index a5f0d9e4..4f491cfb 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py @@ -1427,12 +1427,13 @@ def test_create_cache_point() -> None: def test_create_document() -> None: """Test creating a document.""" document = ChatBedrockConverse.create_document( - name="MyDoc", source={"text": "Cite me"}, enable_citations=True + name="MyDoc", source={"text": "Cite me"}, format="txt", enable_citations=True ) expected_doc = { "document": { "name": "MyDoc", "source": {"text": "Cite me"}, + "format": "txt", "citations": {"enabled": True}, } } From d648de1416a759536d561ce8285636880b7126f4 Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:19:42 -0500 Subject: [PATCH 09/13] tests: added integration test for use of ChatBedrockConverse.create_document usage --- .../chat_models/test_bedrock_converse.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index 2977e629..91543f41 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -494,3 +494,51 @@ def test_bedrock_pdf_inputs() -> None: ] ) _ = model.invoke([message]) + + +def test_bedrock_document_usage() -> None: + model = ChatBedrockConverse( + model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2" + ) + + # Test bytes source typec + url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" + pdf_bytes = httpx.get(url).content + message = HumanMessage( + [ + {"type": "text", "text": "Summarize this document:"}, + ChatBedrockConverse.create_document( + "PDFDoc", source={"bytes": pdf_bytes}, format="pdf" + ), + ] + ) + + _ = model.invoke([message]) + + # Test text source type + text = "I am a text document." + message = HumanMessage( + [ + {"type": "text", "text": "Summarize this document:"}, + ChatBedrockConverse.create_document( + "TextDoc", source={"text": text}, format="txt" + ), + ] + ) + _ = model.invoke([message]) + + # Test content source type + split_text = [ + {"text": "I am the first part of a document."}, + {"text": "I am the second part."}, + {"text": "I am not sure how I got here."}, + ] + message = HumanMessage( + [ + {"type": "text", "text": "Summarize this document:"}, + ChatBedrockConverse.create_document( + "TextDoc", source={"content": split_text}, format="txt" + ), + ] + ) + _ = model.invoke([message]) From 1842cfcb450f3ab3091e14cdb0029d7d285048e1 Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:21:00 -0500 Subject: [PATCH 10/13] fix: parentheses now spelled correctly --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 5ab980da..6d42b63d 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -535,7 +535,7 @@ def create_document( raise ValueError( "Name must be only alphanumeric characters," " whitespace characters (no more than one in a row)," - " hyphens, parantheses, or square brackets." + " hyphens, parentheses, or square brackets." ) valid_source_types = ["bytes", "content", "s3Location", "text"] From 42ce46006481d19166223c9f9b9b92c279a4628c Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:23:24 -0500 Subject: [PATCH 11/13] fix: typo in tests --- .../integration_tests/chat_models/test_bedrock_converse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py index 91543f41..c80c56c6 100644 --- a/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py +++ b/libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py @@ -501,7 +501,7 @@ def test_bedrock_document_usage() -> None: model="us.anthropic.claude-3-5-sonnet-20241022-v2:0", region_name="us-west-2" ) - # Test bytes source typec + # Test bytes source type url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf" pdf_bytes = httpx.get(url).content message = HumanMessage( From d7578511b46df2ddddee2cc2dea4d0e084b5565c Mon Sep 17 00:00:00 2001 From: mathewHatch <41880891+mathewHatch@users.noreply.github.com> Date: Tue, 16 Sep 2025 12:31:47 -0500 Subject: [PATCH 12/13] fix: invalid regex search would raise value error when invalid characters not found using re.search instead of found --- libs/aws/langchain_aws/chat_models/bedrock_converse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 6d42b63d..537a5b2d 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -531,7 +531,7 @@ def create_document( enable_citations: Whether to enable the Citations API for the document. Returns: Dictionary containing a properly formatted to add to message content.""" - if not re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name): + if re.search(r"[^A-Za-z0-9 \[\]()\-]|\s{2,}", name): raise ValueError( "Name must be only alphanumeric characters," " whitespace characters (no more than one in a row)," From 574c1499c1976adb0d7106350f87260ca43fde1a Mon Sep 17 00:00:00 2001 From: Matt Hatch Date: Thu, 25 Sep 2025 19:08:13 -0500 Subject: [PATCH 13/13] fix: added accidental deletions back in --- .../langchain_aws/chat_models/bedrock_converse.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 9b9e89f0..899cee8b 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -721,6 +721,19 @@ def set_disable_streaming(cls, values: Dict) -> Any: def validate_environment(self) -> Self: """Validate that AWS credentials to and python package exists in environment.""" + # Skip creating new client if passed in constructor + if self.client is None: + self.client = create_aws_client( + region_name=self.region_name, + credentials_profile_name=self.credentials_profile_name, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + endpoint_url=self.endpoint_url, + config=self.config, + service_name="bedrock-runtime", + ) + # Create bedrock client for control plane API call if self.bedrock_client is None: bedrock_client_cfg = {}