Skip to content

Commit

Permalink
Add hack for updating final result with data from root value
Browse files Browse the repository at this point in the history
  • Loading branch information
rafalp committed Feb 28, 2024
1 parent f50f140 commit f7c6015
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 5 deletions.
42 changes: 38 additions & 4 deletions ariadne/graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
List,
Optional,
Sequence,
Tuple,
Type,
cast,
Union,
Expand Down Expand Up @@ -39,6 +40,7 @@
ErrorFormatter,
ExtensionList,
GraphQLResult,
GraphQLResultUpdate,
MiddlewareList,
QueryParser,
QueryValidator,
Expand Down Expand Up @@ -146,6 +148,8 @@ async def graphql(
`**kwargs`: any kwargs not used by `graphql` are passed to
`graphql.graphql`.
"""
result_update: Optional[GraphQLResultUpdate] = None

extension_manager = ExtensionManager(extensions, context_value)

with extension_manager.request():
Expand Down Expand Up @@ -200,6 +204,10 @@ async def graphql(
if isawaitable(root_value):
root_value = await root_value

if isinstance(root_value, GraphQLResultUpdate):
result_update = root_value
root_value = root_value.root_value

result = execute(
schema,
document,
Expand All @@ -217,22 +225,32 @@ async def graphql(
if isawaitable(result):
result = await cast(Awaitable[ExecutionResult], result)
except GraphQLError as error:
return handle_graphql_errors(
result = handle_graphql_errors(
[error],
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

return handle_query_result(
if result_update:
return result_update.update_result(result)

Check warning on line 237 in ariadne/graphql.py

View check run for this annotation

Codecov / codecov/patch

ariadne/graphql.py#L237

Added line #L237 was not covered by tests

return result

result = handle_query_result(
result,
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

if result_update:
return result_update.update_result(result)

return result


def graphql_sync(
schema: GraphQLSchema,
Expand Down Expand Up @@ -321,6 +339,8 @@ def graphql_sync(
`**kwargs`: any kwargs not used by `graphql_sync` are passed to
`graphql.graphql_sync`.
"""
result_update: Optional[GraphQLResultUpdate] = None

extension_manager = ExtensionManager(extensions, context_value)

with extension_manager.request():
Expand Down Expand Up @@ -379,6 +399,10 @@ def graphql_sync(
"in synchronous query executor."
)

if isinstance(root_value, GraphQLResultUpdate):
result_update = root_value
root_value = root_value.root_value

result = execute_sync(
schema,
document,
Expand All @@ -399,22 +423,32 @@ def graphql_sync(
"GraphQL execution failed to complete synchronously."
)
except GraphQLError as error:
return handle_graphql_errors(
result = handle_graphql_errors(
[error],
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

return handle_query_result(
if result_update:
return result_update.update_result(result)

Check warning on line 435 in ariadne/graphql.py

View check run for this annotation

Codecov / codecov/patch

ariadne/graphql.py#L435

Added line #L435 was not covered by tests

return result

result = handle_query_result(
result,
logger=logger,
error_formatter=error_formatter,
debug=debug,
extension_manager=extension_manager,
)

if result_update:
return result_update.update_result(result)

return result


async def subscribe(
schema: GraphQLSchema,
Expand Down
30 changes: 30 additions & 0 deletions ariadne/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
__all__ = [
"Resolver",
"GraphQLResult",
"GraphQLResultUpdate",
"SubscriptionResult",
"Subscriber",
"ErrorFormatter",
Expand Down Expand Up @@ -228,6 +229,35 @@ async def get_context_value(request: Request, _):
Callable[[Optional[Any], Optional[str], Optional[dict], DocumentNode], Any],
]


class GraphQLResultUpdate:
"""A `RootValue` wrapper that includes result JSON update logic.
Can be returned by the `RootValue` callable. Not used by Ariadne directly
but part of the support for Ariadne GraphQL Proxy.
# Attributes
- `root_value: Optional[dict]`: `RootValue` to use during query execution.
"""

__slots__ = ("root_value",)

root_value: Optional[dict]

def __init__(self, root_value: Optional[dict] = None):
self.root_value = root_value

def update_result(self, result: GraphQLResult) -> GraphQLResult:
"""An update function used to create a final `GraphQL` result tuple to
create a JSON response from.
Default implementation in `GraphQLResultUpdate` is a passthrough that
returns `result` value without any changes.
"""
return result

Check warning on line 258 in ariadne/types.py

View check run for this annotation

Codecov / codecov/patch

ariadne/types.py#L258

Added line #L258 was not covered by tests


"""Type of `query_parser` option of GraphQL servers.
Enables customization of server's GraphQL parsing logic. If not set or `None`,
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def type_defs():
testContext: String
testRoot: String
testError: Boolean
context: String
}
type Mutation {
Expand Down
36 changes: 35 additions & 1 deletion tests/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from graphql.validation.rules import ValidationRule

from ariadne import graphql, graphql_sync, subscribe
from ariadne.types import GraphQLResultUpdate


class AlwaysInvalid(ValidationRule):
Expand All @@ -12,6 +13,12 @@ def leave_operation_definition( # pylint: disable=unused-argument
self.context.report_error(GraphQLError("Invalid"))


class CustomGraphQLResultUpdate(GraphQLResultUpdate):
def update_result(self, result):
success, data = result
return success, dict(**data, updated=True)


def test_graphql_sync_executes_the_query(schema):
success, result = graphql_sync(schema, {"query": '{ hello(name: "world") }'})
assert success
Expand Down Expand Up @@ -51,8 +58,21 @@ def test_graphql_sync_prevents_introspection_query_when_option_is_disabled(schem
)


def test_graphql_sync_executes_the_query_using_result_update_obj(schema):
success, result = graphql_sync(
schema,
{"query": '{ context }'},
root_value=CustomGraphQLResultUpdate({"context": "Works!"}),
)
assert success
assert result == {
"data": {"context": "Works!"},
"updated": True,
}


@pytest.mark.asyncio
async def test_graphql_execute_the_query(schema):
async def test_graphql_executes_the_query(schema):
success, result = await graphql(schema, {"query": '{ hello(name: "world") }'})
assert success
assert result["data"] == {"hello": "Hello, world!"}
Expand Down Expand Up @@ -94,6 +114,20 @@ async def test_graphql_prevents_introspection_query_when_option_is_disabled(sche
)


@pytest.mark.asyncio
async def test_graphql_executes_the_query_using_result_update_obj(schema):
success, result = await graphql(
schema,
{"query": '{ context }'},
root_value=CustomGraphQLResultUpdate({"context": "Works!"}),
)
assert success
assert result == {
"data": {"context": "Works!"},
"updated": True,
}


@pytest.mark.asyncio
async def test_subscription_returns_an_async_iterator(schema):
success, result = await subscribe(schema, {"query": "subscription { ping }"})
Expand Down

0 comments on commit f7c6015

Please sign in to comment.