diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 8bcf2d6636c..89952580dcc 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -11,7 +11,7 @@ Response, ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver -from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver +from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver, BedrockResponse from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, @@ -26,6 +26,7 @@ "ALBResolver", "ApiGatewayResolver", "BedrockAgentResolver", + "BedrockResponse", "CORSConfig", "LambdaFunctionUrlResolver", "Response", diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index a5c5a7bb053..f1f38b399a9 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -73,6 +73,7 @@ _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" _ROUTE_REGEX = "^{}$" _JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder) +_DEFAULT_CONTENT_TYPE = "application/json" ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) ResponseT = TypeVar("ResponseT") @@ -255,6 +256,35 @@ def build_allow_methods(methods: set[str]) -> str: return ",".join(sorted(methods)) +class BedrockResponse(Generic[ResponseT]): + """ + Contains the response body, status code, content type, and optional attributes + for session management and knowledge base configuration. + """ + + def __init__( + self, + body: Any = None, + status_code: int = 200, + content_type: str = _DEFAULT_CONTENT_TYPE, + session_attributes: dict[str, Any] | None = None, + prompt_session_attributes: dict[str, Any] | None = None, + knowledge_bases_configuration: list[dict[str, Any]] | None = None, + ) -> None: + self.body = body + self.status_code = status_code + self.content_type = content_type + self.session_attributes = session_attributes + self.prompt_session_attributes = prompt_session_attributes + self.knowledge_bases_configuration = knowledge_bases_configuration + + def is_json(self) -> bool: + """ + Returns True if the response is JSON, based on the Content-Type. + """ + return True + + class Response(Generic[ResponseT]): """Response data class that provides greater control over what is returned from the proxy event""" @@ -300,7 +330,7 @@ def is_json(self) -> bool: content_type = self.headers.get("Content-Type", "") if isinstance(content_type, list): content_type = content_type[0] - return content_type.startswith("application/json") + return content_type.startswith(_DEFAULT_CONTENT_TYPE) class Route: @@ -572,7 +602,7 @@ def _get_openapi_path( operation_responses: dict[int, OpenAPIResponse] = { 422: { "description": "Validation Error", - "content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}}, + "content": {_DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}HTTPValidationError"}}}, }, } @@ -581,7 +611,9 @@ def _get_openapi_path( http_code = self.custom_response_validation_http_code.value operation_responses[http_code] = { "description": "Response Validation Error", - "content": {"application/json": {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}}, + "content": { + _DEFAULT_CONTENT_TYPE: {"schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}}, + }, } # Add model definition definitions["ResponseValidationError"] = response_validation_error_response_definition @@ -594,7 +626,7 @@ def _get_openapi_path( # Case 1: there is not 'content' key if "content" not in response: response["content"] = { - "application/json": self._openapi_operation_return( + _DEFAULT_CONTENT_TYPE: self._openapi_operation_return( param=dependant.return_param, model_name_map=model_name_map, field_mapping=field_mapping, @@ -645,7 +677,7 @@ def _get_openapi_path( # Add the response schema to the OpenAPI 200 response operation_responses[200] = { "description": self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, - "content": {"application/json": response_schema}, + "content": {_DEFAULT_CONTENT_TYPE: response_schema}, } operation["responses"] = operation_responses @@ -1474,7 +1506,10 @@ def __call__(self, app: ApiGatewayResolver) -> dict | tuple | Response: return self.current_middleware(app, self.next_middleware) -def _registered_api_adapter(app: ApiGatewayResolver, next_middleware: Callable[..., Any]) -> dict | tuple | Response: +def _registered_api_adapter( + app: ApiGatewayResolver, + next_middleware: Callable[..., Any], +) -> dict | tuple | Response | BedrockResponse: """ Calls the registered API using the "_route_args" from the Resolver context to ensure the last call in the chain will match the API route function signature and ensure that Powertools passes the API @@ -1632,7 +1667,7 @@ def _add_resolver_response_validation_error_response_to_route( response_validation_error_response = { "description": "Response Validation Error", "content": { - "application/json": { + _DEFAULT_CONTENT_TYPE: { "schema": {"$ref": f"{COMPONENT_REF_PREFIX}ResponseValidationError"}, }, }, @@ -2151,7 +2186,7 @@ def swagger_handler(): if query_params.get("format") == "json": return Response( status_code=200, - content_type="application/json", + content_type=_DEFAULT_CONTENT_TYPE, body=escaped_spec, ) @@ -2538,7 +2573,7 @@ def _call_route(self, route: Route, route_arguments: dict[str, str]) -> Response self._reset_processed_stack() return self._response_builder_class( - response=self._to_response( + response=self._to_response( # type: ignore[arg-type] route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments), ), serializer=self._serializer, @@ -2627,7 +2662,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild return None - def _to_response(self, result: dict | tuple | Response) -> Response: + def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse: """Convert the route's result to a Response 3 main result types are supported: @@ -2638,7 +2673,7 @@ def _to_response(self, result: dict | tuple | Response) -> Response: - Response: returned as is, and allows for more flexibility """ status_code = HTTPStatus.OK - if isinstance(result, Response): + if isinstance(result, (Response, BedrockResponse)): return result elif isinstance(result, tuple) and len(result) == 2: # Unpack result dict and status code from tuple @@ -2971,8 +3006,9 @@ def _get_base_path(self) -> str: # ALB doesn't have a stage variable, so we just return an empty string return "" + # BedrockResponse is not used here but adding the same signature to keep strong typing @override - def _to_response(self, result: dict | tuple | Response) -> Response: + def _to_response(self, result: dict | tuple | Response | BedrockResponse) -> Response | BedrockResponse: """Convert the route's result to a Response ALB requires a non-null body otherwise it converts as HTTP 5xx diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index 221163b1ae4..c3b48bcb95e 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -8,6 +8,7 @@ from aws_lambda_powertools.event_handler import ApiGatewayResolver from aws_lambda_powertools.event_handler.api_gateway import ( _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + BedrockResponse, ProxyEventType, ResponseBuilder, ) @@ -32,14 +33,11 @@ class BedrockResponseBuilder(ResponseBuilder): @override def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]: - """Build the full response dict to be returned by the lambda""" - self._route(event, None) - body = self.response.body if self.response.is_json() and not isinstance(self.response.body, str): body = self.serializer(self.response.body) - return { + response = { "messageVersion": "1.0", "response": { "actionGroup": event.action_group, @@ -54,6 +52,19 @@ def build(self, event: BedrockAgentEvent, *args) -> dict[str, Any]: }, } + # Add Bedrock-specific attributes + if isinstance(self.response, BedrockResponse): + if self.response.session_attributes: + response["sessionAttributes"] = self.response.session_attributes + + if self.response.prompt_session_attributes: + response["promptSessionAttributes"] = self.response.prompt_session_attributes + + if self.response.knowledge_bases_configuration: + response["knowledgeBasesConfiguration"] = self.response.knowledge_bases_configuration + + return response + class BedrockAgentResolver(ApiGatewayResolver): """Bedrock Agent Resolver diff --git a/docs/core/event_handler/bedrock_agents.md b/docs/core/event_handler/bedrock_agents.md index 9665628ff30..b7626f32f97 100644 --- a/docs/core/event_handler/bedrock_agents.md +++ b/docs/core/event_handler/bedrock_agents.md @@ -323,6 +323,17 @@ You can enable user confirmation with Bedrock Agents to have your application as 1. Add an openapi extension +### Fine grained responses + +???+ info "Note" + The default response only includes the essential fields to keep the payload size minimal, as AWS Lambda has a maximum response size of 25 KB. + +You can use `BedrockResponse` class to add additional fields as needed, such as [session attributes, prompt session attributes, and knowledge base configurations](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-response){target="_blank"}. + +```python title="working_with_bedrockresponse.py" title="Customzing your Bedrock Response" hl_lines="5 16" +--8<-- "examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py" +``` + ## Testing your code Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request: diff --git a/examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py b/examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py new file mode 100644 index 00000000000..25e2a56eee1 --- /dev/null +++ b/examples/event_handler_bedrock_agents/src/working_with_bedrockresponse.py @@ -0,0 +1,35 @@ +from http import HTTPStatus + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import BedrockAgentResolver +from aws_lambda_powertools.event_handler.api_gateway import BedrockResponse +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = BedrockAgentResolver() + + +@app.get("/return_with_session", description="Returns a hello world with session attributes") +@tracer.capture_method +def hello_world(): + return BedrockResponse( + status_code=HTTPStatus.OK.value, + body={"message": "Hello from Bedrock!"}, + session_attributes={"user_id": "123"}, + prompt_session_attributes={"context": "testing"}, + knowledge_bases_configuration=[ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"}, + }, + }, + ], + ) + + +@logger.inject_lambda_context +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext): + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/test_bedrock_agent.py b/tests/functional/event_handler/_pydantic/test_bedrock_agent.py index eaa666b3d36..fff0f8b7d42 100644 --- a/tests/functional/event_handler/_pydantic/test_bedrock_agent.py +++ b/tests/functional/event_handler/_pydantic/test_bedrock_agent.py @@ -4,7 +4,7 @@ import pytest from typing_extensions import Annotated -from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types +from aws_lambda_powertools.event_handler import BedrockAgentResolver, BedrockResponse, Response, content_types from aws_lambda_powertools.event_handler.openapi.params import Body from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent from tests.functional.utils import load_event @@ -202,6 +202,133 @@ def handler() -> Optional[Dict]: assert schema.get("openapi") == "3.0.3" +def test_bedrock_agent_with_bedrock_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + # WHEN using BedrockResponse + @app.get("/claims", description="Gets claims") + def claims(): + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + return BedrockResponse( + session_attributes={"user_id": "123"}, + prompt_session_attributes={"context": "testing"}, + knowledge_bases_configuration=[ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"}, + }, + }, + ], + ) + + result = app(load_event("bedrockAgentEvent.json"), {}) + + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["sessionAttributes"] == {"user_id": "123"} + assert result["promptSessionAttributes"] == {"context": "testing"} + assert result["knowledgeBasesConfiguration"] == [ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": { + "vectorSearchConfiguration": {"numberOfResults": 3, "overrideSearchType": "HYBRID"}, + }, + }, + ] + + +def test_bedrock_agent_with_empty_bedrock_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims(): + return BedrockResponse(body={"message": "test"}) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly without optional attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + assert "sessionAttributes" not in result + assert "promptSessionAttributes" not in result + assert "knowledgeBasesConfiguration" not in result + + +def test_bedrock_agent_with_partial_bedrock_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims() -> Dict[str, Any]: + return BedrockResponse( + body={"message": "test"}, + session_attributes={"user_id": "123"}, + # Only include session_attributes to test partial response + ) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly with only session_attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + assert result["sessionAttributes"] == {"user_id": "123"} + assert "promptSessionAttributes" not in result + assert "knowledgeBasesConfiguration" not in result + + +def test_bedrock_agent_with_string(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims() -> str: + return "a" + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly with only session_attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + + +def test_bedrock_agent_with_different_attributes_combination(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims", description="Gets claims") + def claims() -> Dict[str, Any]: + return BedrockResponse( + body={"message": "test"}, + prompt_session_attributes={"context": "testing"}, + knowledge_bases_configuration=[ + { + "knowledgeBaseId": "kb-123", + "retrievalConfiguration": {"vectorSearchConfiguration": {"numberOfResults": 3}}, + }, + ], + # Omit session_attributes to test different combination + ) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly with specific attributes + assert result["messageVersion"] == "1.0" + assert result["response"]["httpStatusCode"] == 200 + assert "sessionAttributes" not in result + assert result["promptSessionAttributes"] == {"context": "testing"} + assert result["knowledgeBasesConfiguration"][0]["knowledgeBaseId"] == "kb-123" + + def test_bedrock_resolver_with_openapi_extensions(): # GIVEN BedrockAgentResolver is initialized with enable_validation=True app = BedrockAgentResolver(enable_validation=True)