diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index ffbb2abe4ae..c61cf4df1e1 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -11,7 +11,7 @@ Response, ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver -from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver +from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver, BedrockResponse from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, ) @@ -24,6 +24,7 @@ "ALBResolver", "ApiGatewayResolver", "BedrockAgentResolver", + "BedrockResponse", "CORSConfig", "LambdaFunctionUrlResolver", "Response", diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index e4f41bd38db..dfdbae907a4 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -22,6 +22,29 @@ from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent +class BedrockResponse: + """ + Contains the response body, status code, content type, and optional attributes + for session management and knowledge base configuration. + """ + + def __init__( + self, + body: Any = None, + status_code: int = 200, + content_type: str = "application/json", + session_attributes: dict[str, Any] | None = None, + prompt_session_attributes: dict[str, Any] | None = None, + knowledge_bases_configuration: list[dict[str, Any]] | None = None, + ) -> None: + self.body = body + self.status_code = status_code + self.content_type = content_type + self.session_attributes = session_attributes + self.prompt_session_attributes = prompt_session_attributes + self.knowledge_bases_configuration = knowledge_bases_configuration + + class BedrockResponseBuilder(ResponseBuilder): """ Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents. @@ -34,11 +57,17 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]: """Build the full response dict to be returned by the lambda""" self._route(event, None) - body = self.response.body - if self.response.is_json() and not isinstance(self.response.body, str): - body = self.serializer(self.response.body) + bedrock_response = None + if isinstance(self.response.body, dict) and "body" in self.response.body: + bedrock_response = BedrockResponse(**self.response.body) + body = bedrock_response.body + else: + body = self.response.body + + if self.response.is_json() and not isinstance(body, str): + body = self.serializer(body) - return { + response = { "messageVersion": "1.0", "response": { "actionGroup": event.action_group, @@ -53,6 +82,17 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]: }, } + # Add Bedrock-specific attributes + if bedrock_response: + if bedrock_response.session_attributes: + response["sessionAttributes"] = bedrock_response.session_attributes + if bedrock_response.prompt_session_attributes: + response["promptSessionAttributes"] = bedrock_response.prompt_session_attributes + if bedrock_response.knowledge_bases_configuration: + response["knowledgeBasesConfiguration"] = bedrock_response.knowledge_bases_configuration # type: ignore + + return response + class BedrockAgentResolver(ApiGatewayResolver): """Bedrock Agent Resolver diff --git a/docs/core/event_handler/bedrock_agents.md b/docs/core/event_handler/bedrock_agents.md index 6c514012f72..612b9a4ceb7 100644 --- a/docs/core/event_handler/bedrock_agents.md +++ b/docs/core/event_handler/bedrock_agents.md @@ -313,6 +313,19 @@ To implement these customizations, include extra parameters when defining your r --8<-- "examples/event_handler_bedrock_agents/src/customizing_bedrock_api_operations.py" ``` +### Fine grained responses + +`BedrockResponse` class that provides full control over Bedrock Agent responses. + +You can use this class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response). + +???+ info "Note" + The default response only includes the essential fields to keep the payload size minimal, as AWS Lambda has a maximum response size of 25 KB. + +```python title="bedrockresponse.py" title="Customzing your Bedrock Response" +--8<-- "examples/event_handler_bedrock_agents/src/bedrockresponse.py" +``` + ## Testing your code Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request: diff --git a/examples/event_handler_bedrock_agents/src/bedrockresponse.py b/examples/event_handler_bedrock_agents/src/bedrockresponse.py new file mode 100644 index 00000000000..51ead4ee05e --- /dev/null +++ b/examples/event_handler_bedrock_agents/src/bedrockresponse.py @@ -0,0 +1,18 @@ +from http import HTTPStatus + +from aws_lambda_powertools.event_handler import BedrockResponse + +response = BedrockResponse( + status_code=HTTPStatus.OK.value, + body={"message": "Hello from Bedrock!"}, + session_attributes={"user_id": "123"}, + prompt_session_attributes={"context": "testing"}, + knowledge_bases_configuration=[ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"}, + }, + }, + ], +) diff --git a/tests/functional/event_handler/_pydantic/test_bedrock_agent.py b/tests/functional/event_handler/_pydantic/test_bedrock_agent.py index 6dcc55c2da5..24430b8c441 100644 --- a/tests/functional/event_handler/_pydantic/test_bedrock_agent.py +++ b/tests/functional/event_handler/_pydantic/test_bedrock_agent.py @@ -4,7 +4,7 @@ import pytest from typing_extensions import Annotated -from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types +from aws_lambda_powertools.event_handler import BedrockAgentResolver, BedrockResponse, Response, content_types from aws_lambda_powertools.event_handler.openapi.params import Body from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent from tests.functional.utils import load_event @@ -200,3 +200,114 @@ def handler() -> Optional[Dict]: # THEN the schema must be a valid 3.0.3 version assert openapi30_schema(schema) assert schema.get("openapi") == "3.0.3" + + +def test_bedrock_agent_with_bedrock_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + # WHEN using BedrockResponse + @app.get("/claims", description="Gets claims") + def claims(): + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + return BedrockResponse( + session_attributes={"user_id": "123"}, + prompt_session_attributes={"context": "testing"}, + knowledge_bases_configuration=[ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"}, + }, + }, + ], + ) + + result = app(load_event("bedrockAgentEvent.json"), {}) + + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["sessionAttributes"] == {"user_id": "123"} + assert result["promptSessionAttributes"] == {"context": "testing"} + assert result["knowledgeBasesConfiguration"] == [ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"}, + }, + }, + ] + + +def test_bedrock_agent_with_empty_bedrock_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims(): + return BedrockResponse(body={"message": "test"}) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly without optional attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + assert "sessionAttributes" not in result + assert "promptSessionAttributes" not in result + assert "knowledgeBasesConfiguration" not in result + + +def test_bedrock_agent_with_partial_bedrock_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims(): + return BedrockResponse( + body={"message": "test"}, + session_attributes={"user_id": "123"}, + # Only include session_attributes to test partial response + ) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly with only session_attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + assert result["sessionAttributes"] == {"user_id": "123"} + assert "promptSessionAttributes" not in result + assert "knowledgeBasesConfiguration" not in result + + +def test_bedrock_agent_with_different_attributes_combination(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims(): + return BedrockResponse( + body={"message": "test"}, + prompt_session_attributes={"context": "testing"}, + knowledge_bases_configuration=[ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 3}}, + }, + ], + # Omit session_attributes to test different combination + ) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly with specific attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + assert "sessionAttributes" not in result + assert result["promptSessionAttributes"] == {"context": "testing"} + assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb-123"