diff --git a/ariadne/graphql.py b/ariadne/graphql.py index c17f885a..08b24345 100644 --- a/ariadne/graphql.py +++ b/ariadne/graphql.py @@ -9,6 +9,7 @@ List, Optional, Sequence, + Tuple, Type, cast, Union, @@ -39,6 +40,7 @@ ErrorFormatter, ExtensionList, GraphQLResult, + GraphQLResultUpdate, MiddlewareList, QueryParser, QueryValidator, @@ -146,6 +148,8 @@ async def graphql( `**kwargs`: any kwargs not used by `graphql` are passed to `graphql.graphql`. """ + result_update: Optional[GraphQLResultUpdate] = None + extension_manager = ExtensionManager(extensions, context_value) with extension_manager.request(): @@ -200,6 +204,10 @@ async def graphql( if isawaitable(root_value): root_value = await root_value + if isinstance(root_value, GraphQLResultUpdate): + result_update = root_value + root_value = root_value.root_value + result = execute( schema, document, @@ -217,7 +225,7 @@ async def graphql( if isawaitable(result): result = await cast(Awaitable[ExecutionResult], result) except GraphQLError as error: - return handle_graphql_errors( + result = handle_graphql_errors( [error], logger=logger, error_formatter=error_formatter, @@ -225,7 +233,12 @@ async def graphql( extension_manager=extension_manager, ) - return handle_query_result( + if result_update: + return result_update.update_result(result) + + return result + + result = handle_query_result( result, logger=logger, error_formatter=error_formatter, @@ -233,6 +246,11 @@ async def graphql( extension_manager=extension_manager, ) + if result_update: + return result_update.update_result(result) + + return result + def graphql_sync( schema: GraphQLSchema, @@ -321,6 +339,8 @@ def graphql_sync( `**kwargs`: any kwargs not used by `graphql_sync` are passed to `graphql.graphql_sync`. """ + result_update: Optional[GraphQLResultUpdate] = None + extension_manager = ExtensionManager(extensions, context_value) with extension_manager.request(): @@ -379,6 +399,10 @@ def graphql_sync( "in synchronous query executor." ) + if isinstance(root_value, GraphQLResultUpdate): + result_update = root_value + root_value = root_value.root_value + result = execute_sync( schema, document, @@ -399,7 +423,7 @@ def graphql_sync( "GraphQL execution failed to complete synchronously." ) except GraphQLError as error: - return handle_graphql_errors( + result = handle_graphql_errors( [error], logger=logger, error_formatter=error_formatter, @@ -407,7 +431,12 @@ def graphql_sync( extension_manager=extension_manager, ) - return handle_query_result( + if result_update: + return result_update.update_result(result) + + return result + + result = handle_query_result( result, logger=logger, error_formatter=error_formatter, @@ -415,6 +444,11 @@ def graphql_sync( extension_manager=extension_manager, ) + if result_update: + return result_update.update_result(result) + + return result + async def subscribe( schema: GraphQLSchema, diff --git a/ariadne/types.py b/ariadne/types.py index 3dc21f01..dd5e3669 100644 --- a/ariadne/types.py +++ b/ariadne/types.py @@ -29,6 +29,7 @@ __all__ = [ "Resolver", "GraphQLResult", + "GraphQLResultUpdate", "SubscriptionResult", "Subscriber", "ErrorFormatter", @@ -228,6 +229,35 @@ async def get_context_value(request: Request, _): Callable[[Optional[Any], Optional[str], Optional[dict], DocumentNode], Any], ] + +class GraphQLResultUpdate: + """A `RootValue` wrapper that includes result JSON update logic. + + Can be returned by the `RootValue` callable. Not used by Ariadne directly + but part of the support for Ariadne GraphQL Proxy. + + # Attributes + + - `root_value: Optional[dict]`: `RootValue` to use during query execution. + """ + + __slots__ = ("root_value",) + + root_value: Optional[dict] + + def __init__(self, root_value: Optional[dict] = None): + self.root_value = root_value + + def update_result(self, result: GraphQLResult) -> GraphQLResult: + """An update function used to create a final `GraphQL` result tuple to + create a JSON response from. + + Default implementation in `GraphQLResultUpdate` is a passthrough that + returns `result` value without any changes. + """ + return result + + """Type of `query_parser` option of GraphQL servers. Enables customization of server's GraphQL parsing logic. If not set or `None`, diff --git a/tests/conftest.py b/tests/conftest.py index a48173b7..9333a782 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ def type_defs(): testContext: String testRoot: String testError: Boolean + context: String } type Mutation { diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 2fd4ea4a..1789ed05 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -3,6 +3,7 @@ from graphql.validation.rules import ValidationRule from ariadne import graphql, graphql_sync, subscribe +from ariadne.types import GraphQLResultUpdate class AlwaysInvalid(ValidationRule): @@ -12,6 +13,12 @@ def leave_operation_definition( # pylint: disable=unused-argument self.context.report_error(GraphQLError("Invalid")) +class CustomGraphQLResultUpdate(GraphQLResultUpdate): + def update_result(self, result): + success, data = result + return success, dict(**data, updated=True) + + def test_graphql_sync_executes_the_query(schema): success, result = graphql_sync(schema, {"query": '{ hello(name: "world") }'}) assert success @@ -51,8 +58,21 @@ def test_graphql_sync_prevents_introspection_query_when_option_is_disabled(schem ) +def test_graphql_sync_executes_the_query_using_result_update_obj(schema): + success, result = graphql_sync( + schema, + {"query": '{ context }'}, + root_value=CustomGraphQLResultUpdate({"context": "Works!"}), + ) + assert success + assert result == { + "data": {"context": "Works!"}, + "updated": True, + } + + @pytest.mark.asyncio -async def test_graphql_execute_the_query(schema): +async def test_graphql_executes_the_query(schema): success, result = await graphql(schema, {"query": '{ hello(name: "world") }'}) assert success assert result["data"] == {"hello": "Hello, world!"} @@ -94,6 +114,20 @@ async def test_graphql_prevents_introspection_query_when_option_is_disabled(sche ) +@pytest.mark.asyncio +async def test_graphql_executes_the_query_using_result_update_obj(schema): + success, result = await graphql( + schema, + {"query": '{ context }'}, + root_value=CustomGraphQLResultUpdate({"context": "Works!"}), + ) + assert success + assert result == { + "data": {"context": "Works!"}, + "updated": True, + } + + @pytest.mark.asyncio async def test_subscription_returns_an_async_iterator(schema): success, result = await subscribe(schema, {"query": "subscription { ping }"})