diff --git a/.codecov.yml b/.codecov.yml index 5b3d02564ba..020108ccc55 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -30,6 +30,7 @@ ignore: - "tests/**" - "ws_messages_pb2.py" - "cylc/flow/scripts/report_timings.py" + - "cylc/flow/network/graphql_subscribe.py" flag_management: default_rules: diff --git a/changes.d/6478.feat.md b/changes.d/6478.feat.md new file mode 100644 index 00000000000..0dc904dc87a --- /dev/null +++ b/changes.d/6478.feat.md @@ -0,0 +1 @@ +Major version upgrade for graphene/graphql-core dependencies. diff --git a/conda-environment.yml b/conda-environment.yml index fed33d57a88..fa670b7a246 100644 --- a/conda-environment.yml +++ b/conda-environment.yml @@ -5,7 +5,8 @@ dependencies: - ansimarkup >=1.0.0 - async-timeout>=3.0.0 # [py<3.11] - colorama >=0.4,<1.0 - - graphene >=2.1,<3 + - graphql-core >=3.2,<3.3 + - graphene >=3.4.0,<3.5 - graphviz # for static graphing # Note: can't pin jinja2 any higher than this until we give up on Cylc 7 back-compat - jinja2 >=3.0,<3.1 diff --git a/cylc/flow/commands.py b/cylc/flow/commands.py index 555f7609e40..b882fdeaa05 100644 --- a/cylc/flow/commands.py +++ b/cylc/flow/commands.py @@ -66,7 +66,6 @@ List, Optional, TypeVar, - Union, ) from metomi.isodatetime.parsers import TimePointParser @@ -81,7 +80,6 @@ import cylc.flow.flags from cylc.flow.flow_mgr import get_flow_nums_set from cylc.flow.log_level import log_level_to_verbosity -from cylc.flow.network.schema import WorkflowStopMode from cylc.flow.parsec.exceptions import ParsecError from cylc.flow.run_modes import RunMode from cylc.flow.task_id import TaskID @@ -90,6 +88,8 @@ if TYPE_CHECKING: + from enum import Enum + from cylc.flow.scheduler import Scheduler # define a type for command implementations @@ -170,7 +170,7 @@ async def set_prereqs_and_outputs( @_command('stop') async def stop( schd: 'Scheduler', - mode: Union[str, 'StopMode'], + mode: 'Optional[Enum]', cycle_point: Optional[str] = None, # NOTE clock_time YYYY/MM/DD-HH:mm back-compat removed clock_time: Optional[str] = None, @@ -208,12 +208,14 @@ async def stop( schd._update_workflow_state() else: # immediate shutdown - with suppress(KeyError): - # By default, mode from mutation is a name from the - # WorkflowStopMode graphene.Enum, but we need the value - mode = WorkflowStopMode[mode] # type: ignore[misc] try: - mode = StopMode(mode) + # BACK COMPAT: mode=None + # the mode can be `None` for commands issued from older Cylc + # versions + # From: 8.4 + # To: 8.5 + # Remove at: 8.x + mode = StopMode(mode.value) if mode else StopMode.REQUEST_CLEAN except ValueError: raise CommandFailedError(f"Invalid stop mode: '{mode}'") from None schd._set_stop(mode) @@ -303,14 +305,13 @@ async def pause(schd: 'Scheduler'): @_command('set_verbosity') -async def set_verbosity(schd: 'Scheduler', level: Union[int, str]): +async def set_verbosity(schd: 'Scheduler', level: 'Enum'): """Set workflow verbosity.""" try: - lvl = int(level) - LOG.setLevel(lvl) + LOG.setLevel(level.value) except (TypeError, ValueError) as exc: raise CommandFailedError(exc) from None - cylc.flow.flags.verbosity = log_level_to_verbosity(lvl) + cylc.flow.flags.verbosity = log_level_to_verbosity(level.value) yield diff --git a/cylc/flow/flow_mgr.py b/cylc/flow/flow_mgr.py index 67f816982ec..b2dea9d4ace 100644 --- a/cylc/flow/flow_mgr.py +++ b/cylc/flow/flow_mgr.py @@ -41,7 +41,7 @@ def add_flow_opts(parser): parser.add_option( - "--flow", action="append", dest="flow", metavar="FLOW", + "--flow", action="append", dest="flow", metavar="FLOW", default=[], help=f'Assign 
new tasks to all active flows ("{FLOW_ALL}");' f' no flow ("{FLOW_NONE}"); a new flow ("{FLOW_NEW}");' f' or a specific flow (e.g. "2"). The default is "{FLOW_ALL}".' diff --git a/cylc/flow/network/graphql.py b/cylc/flow/network/graphql.py index 1f49a6ee2fd..7bfa3f8a6cf 100644 --- a/cylc/flow/network/graphql.py +++ b/cylc/flow/network/graphql.py @@ -19,32 +19,27 @@ """ -from functools import partial -from inspect import isclass, iscoroutinefunction +from inspect import isclass import logging -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import ( + Any, Awaitable, Callable, TypeVar, Tuple, Dict, Union, cast +) from graphene.utils.str_converters import to_snake_case -from graphql.execution.utils import ( - get_operation_root_type, get_field_def +from graphql import ( + ExecutionContext, + TypeInfo, + TypeInfoVisitor, + Visitor, + visit, + get_named_type, + is_introspection_type, + value_from_ast_untyped ) -from graphql.execution.values import get_argument_values, get_variable_values -from graphql.language.base import parse, print_ast -from graphql.language import ast -from graphql.backend.base import GraphQLBackend, GraphQLDocument -from graphql.backend.core import execute_and_validate -from graphql.utils.base import type_from_ast -from graphql.type.definition import get_named_type -from promise import Promise -from rx import Observable +from graphql.pyutils import AwaitableOrValue, is_awaitable from cylc.flow.network.schema import NODE_MAP -if TYPE_CHECKING: - from graphql.execution import ExecutionResult - from graphql.language.ast import Document - from graphql.type.schema import GraphQLSchema - logger = logging.getLogger(__name__) @@ -53,6 +48,8 @@ EMPTY_VALUES: Tuple[list, dict] = ([], {}) STRIP_OPS = {'query', 'subscription'} +U = TypeVar("U") + def grow_tree(tree, path, leaves=None): """Additively grows tree with leaves at terminal of new branch. @@ -94,6 +91,30 @@ def instantiate_middleware(middlewares): yield middleware +async def async_callback( + callback: Callable[[U], AwaitableOrValue[U]], + result: AwaitableOrValue[U], +) -> U: + """Await result and apply callback.""" + result = callback(await cast('Awaitable[Any]', result)) + return await result if is_awaitable(result) else result # type: ignore + + +def async_next( + callback: Callable[[U], AwaitableOrValue[U]], + result: AwaitableOrValue[U], +) -> AwaitableOrValue[U]: + """Reduce the given potentially awaitable values using a callback function. + + If the callback does not return an awaitable, then this function will also + not return an awaitable. + """ + if is_awaitable(result): + return async_callback(callback, result) + else: + return callback(cast('U', result)) + + def null_setter(result): """Set type to null if result is empty/null-like.""" # Only set empty parents to null. @@ -111,8 +132,6 @@ def null_setter(result): # However, middleware allows for argument of the request doc to set. 
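+# For example (a sketch; see the strip_null unit tests for exact behaviour):
+#     strip_null({'meta': NULL_VALUE, 'id': 'foo'})  # -> {'id': 'foo'}
+#     strip_null([NULL_VALUE])                       # -> []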
def strip_null(data): """Recursively strip data structure of nulls.""" - if isinstance(data, Promise): - return data.then(strip_null) if isinstance(data, dict): return { key: strip_null(val) @@ -128,200 +147,70 @@ def strip_null(data): return data -def attr_strip_null(result): - """Work on the attribute/data of ExecutionResult if present.""" - if hasattr(result, 'data'): - result.data = strip_null(result.data) - return result - return strip_null(result) - - -def null_stripper(exe_result): - """Strip nulls in accordance with type of execution result.""" - if isinstance(exe_result, Observable): - return exe_result.map(attr_strip_null) - if not exe_result.errors: - return attr_strip_null(exe_result) - return exe_result - - -class AstDocArguments: - """Request doc Argument inspection.""" - - def __init__(self, schema, document_ast, variable_values): - self.schema = schema - self.operation_defs = {} - self.fragment_defs = {} - self.visited_fragments = set() - - for defn in document_ast.definitions: - if isinstance(defn, ast.OperationDefinition): - root_type = get_operation_root_type(schema, defn) - definition_variables = defn.variable_definitions or [] - if definition_variables: - def_var_names = { - v.variable.name.value - for v in definition_variables - } - var_names_diff = def_var_names.difference({ - k - for k in variable_values - if k in def_var_names - }) - # check if we are missing some of the definition variables - if var_names_diff: - msg = (f'Please check your query variables. The ' - f'following variables are missing: ' - f'[{", ".join(var_names_diff)}]') - raise ValueError(msg) - self.operation_defs[getattr(defn.name, 'value', root_type)] = { - 'definition': defn, - 'parent_type': root_type, - 'variables': get_variable_values( - schema, - definition_variables, - variable_values - ), - } - elif isinstance(defn, ast.FragmentDefinition): - self.fragment_defs[defn.name.value] = defn - - def has_arg_val(self, arg_name, arg_value): - """Search through document definitions for argument value. - - Args: - arg_name (str): Field argument to search for. - arg_value (Any): Argument value required. 
- - Returns: - - Boolean - - """ - for components in self.operation_defs.values(): - defn = components['definition'] - if ( - defn.operation not in STRIP_OPS - or getattr( - defn.name, 'value', None) == 'IntrospectionQuery' - ): - continue - if self.args_selection_search( - components['definition'].selection_set, - components['variables'], - components['parent_type'], - arg_name, - arg_value, - ): - return True - return False - - def args_selection_search( - self, selection_set, variables, parent_type, arg_name, arg_value): - """Recursively search through feild/fragment selection set fields.""" - for field in selection_set.selections: - if isinstance(field, ast.FragmentSpread): - if field.name.value in self.visited_fragments: - continue - frag_def = self.fragment_defs[field.name.value] - frag_type = type_from_ast(self.schema, frag_def.type_condition) - if self.args_selection_search( - frag_def.selection_set, variables, - frag_type, arg_name, arg_value): - return True - self.visited_fragments.add(frag_def.name) - continue - field_def = get_field_def( - self.schema, parent_type, field.name.value) - if field_def is None: - continue - arg_vals = get_argument_values( - field_def.args, field.arguments, variables) - if arg_vals.get(arg_name) == arg_value: - return True - if field.selection_set is None: - continue - if self.args_selection_search( - field.selection_set, variables, - get_named_type(field_def.type), arg_name, arg_value): - return True - return False - - -def execute_and_validate_and_strip( - schema: 'GraphQLSchema', - document_ast: 'Document', - *args: Any, - **kwargs: Any -) -> Union['ExecutionResult', Observable]: - """Wrapper around graphql ``execute_and_validate()`` that adds - null stripping.""" - result = execute_and_validate(schema, document_ast, *args, **kwargs) - # Search request document to determine if 'stripNull: true' is set - # as and argument. It can not be done in the middleware, as they - # can be Promises/futures (so may not been resolved at this point). - variable_values = kwargs['variable_values'] or {} - doc_args = AstDocArguments(schema, document_ast, variable_values) - if doc_args.has_arg_val(STRIP_ARG, True): - if kwargs.get('return_promise', False) and hasattr(result, 'then'): - return result.then(null_stripper) # type: ignore[union-attr] - return null_stripper(result) - return result - - -class CylcGraphQLBackend(GraphQLBackend): - """Return a GraphQL document using the default - graphql executor with optional null-stripping of result. - - The null value stripping of result is triggered by the presence - of argument & value "stripNull: true" in any field. - - This is a modification of GraphQLCoreBackend found within: - https://github.com/graphql-python/graphql-core-legacy - (graphql-core==2.3.2) - - Args: - - executor (object): Executor used in evaluating the resolvers. +class CylcVisitor(Visitor): + """Traverse graphql document/query to find an argument in a given state. + Find whether an argument is set to a specific value anywhere in the + document (i.e. 'strip_null' set to 'True'), and stop on the first + occurrence. 
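+
+    For example, given ``doc_arg={'stripNull': True}``, a request such as
+    ``taskProxies(stripNull: true) { id }`` (an illustrative query) sets
+    ``arg_flag`` and halts the traversal at that argument node.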
""" + def __init__(self, type_info, variable_values, doc_arg) -> None: + super().__init__() + self.type_info = type_info + self.variable_values = variable_values + self.doc_arg = doc_arg + self.arg_flag = False + + def enter(self, node, key, parent, path, ancestors): + if ( + node.kind == 'argument' + and node.name.value in self.doc_arg + and self.doc_arg[node.name.value] == value_from_ast_untyped( + node.value, + self.variable_values + ) + ): + self.arg_flag = True + return self.BREAK + return self.IDLE - def __init__(self, executor=None): - self.execute_params = {"executor": executor} - - def document_from_string(self, schema, document_string): - """Parse string and setup request document for execution. - - Args: + def leave(self, node, key, parent, path, ancestors): + return self.IDLE - schema (graphql.GraphQLSchema): - Schema definition object - document_string (str): - Request query/mutation/subscription document. - Returns: +class CylcExecutionContext(ExecutionContext): - graphql.GraphQLDocument + def execute_operation( + self, operation, root_value + ) -> AwaitableOrValue[Union[Dict[str, Any], Any, None]]: + """Execute the GraphQL document, and apply requested stipping. + Search request document to determine if 'stripNull: true' is set + as and argument. It can not be done in the middleware, as they + can have awaitables and is prior to validation. """ - if isinstance(document_string, ast.Document): - document_ast = document_string - document_string = print_ast(document_ast) - else: - if not isinstance(document_string, str): - logger.error("The query must be a string") - document_ast = parse(document_string) - return GraphQLDocument( - schema=schema, - document_string=document_string, - document_ast=document_ast, - execute=partial( - execute_and_validate_and_strip, - schema, - document_ast, - **self.execute_params + result = super().execute_operation(operation, root_value) + + # Traverse the document stop if found + type_info = TypeInfo(self.schema) + cylc_visitor = CylcVisitor( + type_info, + self.variable_values, + { + 'stripNull': True + } + ) + visit( + self.operation, + TypeInfoVisitor( + type_info, + cylc_visitor ), + None ) + if cylc_visitor.arg_flag: + return async_next(strip_null, result) # type: ignore + return result # -- Middleware -- @@ -329,10 +218,6 @@ def document_from_string(self, schema, document_string): class IgnoreFieldMiddleware: """Set to null/None type undesired field values for stripping.""" - # Sometimes `next` is a Partial(coroutine) or Promise, - # making inspection for know how to resolve it difficult. - ASYNC_OPS = {'query', 'mutation'} - def __init__(self): self.args_tree = {} self.tree_paths = set() @@ -341,21 +226,23 @@ def __init__(self): def resolve(self, next_, root, info, **args): """Middleware resolver; handles field according to operation.""" # GraphiQL introspection is 'query' but not async - if getattr(info.operation.name, 'value', None) == 'IntrospectionQuery': + if is_introspection_type(get_named_type(info.return_type)): return next_(root, info, **args) - if info.operation.operation in STRIP_OPS: - path_string = f'{info.path}' + if info.operation.operation.value in STRIP_OPS: + path_list = info.path.as_list() + path_string = f'{path_list}' + parent_path_string = f'{path_list[:-1:]}' # Needed for child fields that resolve without args. # Store arguments of parents as leaves of schema tree from path # to respective field. 
             # no need to regrow the tree on every subscription push/delta
             if args and path_string not in self.tree_paths:
-                grow_tree(self.args_tree, info.path, args)
+                grow_tree(self.args_tree, path_list, args)
                 self.tree_paths.add(path_string)
             if STRIP_ARG not in args:
                 branch = self.args_tree
-                for section in info.path:
+                for section in path_list:
                     if section not in branch:
                         break
                     branch = branch[section]
@@ -381,7 +268,6 @@ def resolve(self, next_, root, info, **args):
             ):
                 # Gather fields set in root
-                parent_path_string = f'{info.path[:-1:]}'
                 stamp = getattr(root, 'stamp', '')
                 if (
                     parent_path_string not in self.field_sets
@@ -414,25 +300,6 @@ def resolve(self, next_, root, info, **args):
                 )
             ):
                 return None
-            if (
-                info.operation.operation in self.ASYNC_OPS
-                or iscoroutinefunction(next_)
-            ):
-                return self.async_null_setter(next_, root, info, **args)
-            return null_setter(next_(root, info, **args))
+            return async_next(null_setter, next_(root, info, **args))
 
-        if (
-            info.operation.operation in self.ASYNC_OPS
-            or iscoroutinefunction(next_)
-        ):
-            return self.async_resolve(next_, root, info, **args)
         return next_(root, info, **args)
-
-    async def async_resolve(self, next_, root, info, **args):
-        """Return awaited coroutine"""
-        return await next_(root, info, **args)
-
-    async def async_null_setter(self, next_, root, info, **args):
-        """Set type to null after awaited result if empty/null-like."""
-        result = await next_(root, info, **args)
-        return null_setter(result)
diff --git a/cylc/flow/network/graphql_subscribe.py b/cylc/flow/network/graphql_subscribe.py
new file mode 100644
index 00000000000..a1a67904740
--- /dev/null
+++ b/cylc/flow/network/graphql_subscribe.py
@@ -0,0 +1,282 @@
+# MIT License
+#
+# Copyright (c) GraphQL Contributors (GraphQL.js)
+# Copyright (c) Syrus Akbary (GraphQL-core 2)
+# Copyright (c) Christoph Zwerschke (GraphQL-core 3)
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# ----------------------------------------------------------------------------
+#
+# The code in this file originates from graphql-core
+# https://github.com/graphql-python/graphql-core/blob/v3.2.6/src/graphql/execution/subscribe.py
+#
+# It was modified to include `execution_context_class` and `middleware` in
+# the execution of GraphQL subscriptions.
+# This should no longer be necessary in a future graphql-core release, as
+# the head of graphql-core already includes these additions.
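+#
+# A minimal usage sketch (illustrative only: names are as exercised in the
+# Cylc test suite, and `resolvers` stands in for a resolvers object supplied
+# by the caller):
+#
+#     from graphql import parse
+#     from cylc.flow.network.graphql import CylcExecutionContext
+#     from cylc.flow.network.graphql_subscribe import subscribe
+#     from cylc.flow.network.schema import SUB_RESOLVER_MAPPING, schema
+#
+#     stream = await subscribe(
+#         schema.graphql_schema,
+#         parse('subscription { workflows { id } }'),
+#         context_value={'resolvers': resolvers, 'meta': {}},
+#         execution_context_class=CylcExecutionContext,
+#         subscribe_resolver_map=SUB_RESOLVER_MAPPING,
+#     )
+#     async for result in stream:
+#         ...  # each item is an ExecutionResult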
+# +# BACK COMPAT: graphql_subscribe.py +# FROM: graphql-core 3.2 +# TO: graphql-core 3.3 +# URL: https://github.com/cylc/cylc-flow/issues/6688 + + +from inspect import isawaitable +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + AsyncIterator, + Dict, + Optional, + Type, + Union, +) + +from graphql.error import GraphQLError, located_error +from graphql.execution.collect_fields import collect_fields +from graphql.execution.execute import ( + assert_valid_execution_arguments, + execute, + get_field_def, + ExecutionContext, + ExecutionResult, + Middleware, +) +from graphql.execution.values import get_argument_values +from graphql.pyutils import Path, inspect +from graphql.execution.map_async_iterator import MapAsyncIterator + + +if TYPE_CHECKING: + from graphql.language import DocumentNode + from graphql.type import GraphQLFieldResolver, GraphQLSchema + + +__all__ = ["subscribe", "create_source_event_stream"] + + +async def subscribe( + schema: 'GraphQLSchema', + document: 'DocumentNode', + root_value: Any = None, + context_value: Any = None, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + field_resolver: 'Optional[GraphQLFieldResolver]' = None, + subscribe_field_resolver: 'Optional[GraphQLFieldResolver]' = None, + middleware: Optional[Middleware] = None, + execution_context_class: Optional[Type["ExecutionContext"]] = None, + subscribe_resolver_map: 'Optional[Dict[str, GraphQLFieldResolver]]' = None, +) -> Union[AsyncIterator[ExecutionResult], ExecutionResult]: + """Create a GraphQL subscription. + + Implements the "Subscribe" algorithm described in the GraphQL spec. + + Returns a coroutine object which yields either an AsyncIterator + (if successful) or an ExecutionResult (client error). The coroutine will + raise an exception if a server error occurs. + + If the client-provided arguments to this function do not result in a + compliant subscription, a GraphQL Response (ExecutionResult) with + descriptive errors and no data will be returned. + + If the source stream could not be created due to faulty subscription + resolver logic or underlying systems, the coroutine object will yield a + single ExecutionResult containing ``errors`` and no ``data``. + + If the operation succeeded, the coroutine will yield an AsyncIterator, + which yields a stream of ExecutionResults representing the response stream. + """ + result_or_stream = await create_source_event_stream( + schema, + document, + root_value, + context_value, + variable_values, + operation_name, + subscribe_field_resolver, + middleware, + execution_context_class, + subscribe_resolver_map + ) + if isinstance(result_or_stream, ExecutionResult): + return result_or_stream + + async def map_source_to_response(payload: Any) -> ExecutionResult: + """Map source to response. + + For each payload yielded from a subscription, map it over the normal + GraphQL :func:`~graphql.execute` function, with ``payload`` as the + ``root_value``. This implements the "MapSourceToResponseEvent" + algorithm described in the GraphQL specification. The + :func:`~graphql.execute` function provides the + "ExecuteSubscriptionEvent" algorithm, as it is nearly identical to the + "ExecuteQuery" algorithm, for which :func:`~graphql.execute` is also + used. 
+ """ + result = execute( + schema, + document, + payload, + context_value, + variable_values, + operation_name, + field_resolver, + middleware=middleware, + execution_context_class=execution_context_class, + ) + return await result if isawaitable(result) else result + + # Map every source value to a ExecutionResult value as described above. + return MapAsyncIterator(result_or_stream, map_source_to_response) + + +async def create_source_event_stream( + schema: 'GraphQLSchema', + document: 'DocumentNode', + root_value: Any = None, + context_value: Any = None, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + subscribe_field_resolver: 'Optional[GraphQLFieldResolver]' = None, + middleware: Optional[Middleware] = None, + execution_context_class: Optional[Type["ExecutionContext"]] = None, + subscribe_resolver_map: 'Optional[Dict[str, GraphQLFieldResolver]]' = None, +) -> Union[AsyncIterable[Any], ExecutionResult]: + """Create source event stream + + Implements the "CreateSourceEventStream" algorithm described in the GraphQL + specification, resolving the subscription source event stream. + + Returns a coroutine that yields an AsyncIterable. + + If the client-provided arguments to this function do not result in a + compliant subscription, a GraphQL Response (ExecutionResult) with + descriptive errors and no data will be returned. + + If the source stream could not be created due to faulty subscription + resolver logic or underlying systems, the coroutine object will yield a + single ExecutionResult containing ``errors`` and no ``data``. + + A source event stream represents a sequence of events, each of which + triggers a GraphQL execution for that event. + + This may be useful when hosting the stateful subscription service in a + different process or machine than the stateless GraphQL execution engine, + or otherwise separating these two steps. For more on this, see the + "Supporting Subscriptions at Scale" information in the GraphQL spec. + """ + # If arguments are missing or incorrectly typed, this is an internal + # developer mistake which should throw an early error. + assert_valid_execution_arguments(schema, document, variable_values) + + if execution_context_class is None: + execution_context_class = ExecutionContext + + # If a valid context cannot be created due to incorrect arguments, + # a "Response" with only errors is returned. + context = execution_context_class.build( + schema, + document, + root_value, + context_value, + variable_values, + operation_name, + subscribe_field_resolver=subscribe_field_resolver, + ) + + # Return early errors if execution context failed. + if isinstance(context, list): + return ExecutionResult(data=None, errors=context) + + try: + event_stream = await execute_subscription( + context, + subscribe_resolver_map + ) + + # Assert field returned an event stream, otherwise yield an error. + if not isinstance(event_stream, AsyncIterable): + raise TypeError( + "Subscription field must return AsyncIterable." + f" Received: {inspect(event_stream)}." + ) + return event_stream + + except GraphQLError as error: + # Report it as an ExecutionResult, containing only errors and no data. 
+ return ExecutionResult(data=None, errors=[error]) + + +async def execute_subscription( + context: ExecutionContext, + subscribe_resolver_map: 'Optional[Dict[str, GraphQLFieldResolver]]' = None, +) -> AsyncIterable[Any]: + schema = context.schema + + root_type = schema.subscription_type + if root_type is None: + raise GraphQLError( + "Schema is not configured to execute subscription operation.", + context.operation, + ) + + root_fields = collect_fields( + schema, + context.fragments, + context.variable_values, + root_type, + context.operation.selection_set, + ) + response_name, field_nodes = next(iter(root_fields.items())) + field_def = get_field_def(schema, root_type, field_nodes[0]) + + if not field_def: + field_name = field_nodes[0].name.value + raise GraphQLError( + f"The subscription field '{field_name}' is not defined.", + field_nodes + ) + + path = Path(None, response_name, root_type.name) + info = context.build_resolve_info(field_def, field_nodes, root_type, path) + + # Call the `subscribe()` resolver or the default resolver to produce an + # AsyncIterable yielding raw payloads. + if subscribe_resolver_map is not None: + resolve_fn = subscribe_resolver_map.get( + info.field_name, + field_def.subscribe or context.subscribe_field_resolver + ) + else: + resolve_fn = field_def.subscribe or context.subscribe_field_resolver + + # Implements the "ResolveFieldEventStream" algorithm from GraphQL + # specification. It differs from "ResolveFieldValue" due to providing a + # different `resolveFn`. + + try: + # Build a dictionary of arguments from the field.arguments AST, using + # the variables scope to fulfill any variable references. + args = get_argument_values( + field_def, field_nodes[0], context.variable_values) + + event_stream = resolve_fn(context.root_value, info, **args) + if context.is_awaitable(event_stream): + event_stream = await event_stream + if isinstance(event_stream, Exception): + raise event_stream + + return event_stream + except Exception as error: + raise located_error(error, field_nodes, path.as_list()) from error diff --git a/cylc/flow/network/resolvers.py b/cylc/flow/network/resolvers.py index b97a56a93b3..c2eedebaca4 100644 --- a/cylc/flow/network/resolvers.py +++ b/cylc/flow/network/resolvers.py @@ -46,6 +46,7 @@ DELTA_ADDED, create_delta_store ) import cylc.flow.flags +from cylc.flow.flow_mgr import FLOW_ALL from cylc.flow.id import Tokens from cylc.flow.network.schema import ( DEF_TYPES, @@ -58,8 +59,9 @@ from cylc.flow.util import uniq, iter_uniq if TYPE_CHECKING: + from enum import Enum from uuid import UUID - from graphql import ResolveInfo + from graphql import GraphQLResolveInfo from cylc.flow.data_store_mgr import DataStoreMgr from cylc.flow.scheduler import Scheduler @@ -507,7 +509,7 @@ async def get_nodes_edges(self, root_nodes, args): edges=sort_elements(edges, args)) async def subscribe_delta( - self, root, info: 'ResolveInfo', args + self, root, info: 'GraphQLResolveInfo', args ) -> AsyncGenerator[Any, None]: """Delta subscription async generator. 
@@ -518,8 +520,9 @@ async def subscribe_delta(
         # NOTE: we don't expect workflows to be returned in definition order
         #       so it is ok to use `set` here
         workflow_ids = set(args.get('workflows', args.get('ids', ())))
+
         sub_id = uuid4()
-        info.variable_values['backend_sub_id'] = sub_id
+        info.context['sub_id'] = sub_id
         self.delta_store[sub_id] = {}
 
         op_id = root
@@ -638,7 +641,7 @@ async def flow_delta_processed(self, context, op_id):
     @abstractmethod
     async def mutator(
         self,
-        info: 'ResolveInfo',
+        info: 'GraphQLResolveInfo',
         command: str,
         w_args: Dict[str, Any],
         kwargs: Dict[str, Any],
@@ -659,7 +662,7 @@ def __init__(self, data: 'DataStoreMgr', schd: 'Scheduler') -> None:
 
     # Mutations
     async def mutator(
         self,
-        _info: 'ResolveInfo',
+        _info: 'GraphQLResolveInfo',
         command: str,
         w_args: Dict[str, Any],
         kwargs: Dict[str, Any],
@@ -673,6 +676,17 @@ async def mutator(
             return [{
                 'response': (False, f'No matching workflow in {workflows}')}]
         w_id = w_ids[0]
+        # BACK COMPAT: transform "None" to "[]"
+        # url: https://github.com/cylc/cylc-flow/pull/6478
+        # from: <8.5.0
+        # to: >=8.5.0
+        # remove at: 8.x
+        # With graphql-core v3, None no longer falls back to default_value.
+        if 'flow' in kwargs:
+            kwargs['flow'] = (
+                kwargs['flow']
+                or ([FLOW_ALL] if command == 'remove_tasks' else [])
+            )
         result = await self._mutation_mapper(command, kwargs, meta)
         return [{'id': w_id, 'response': result}]
@@ -691,11 +705,14 @@ async def _mutation_mapper(
         else:
             log_user = f" from {user}"
 
-        log1 = f'Command "{command}" received{log_user}.'
-        log2 = (
+        received_msg = f'Command "{command}" received{log_user}.'
+        signature_str = (
             f"{command}("
             + ", ".join(
-                f"{key}={value}" for key, value in kwargs.items())
+                f"{key}={getattr(value, 'value', value)}"
+                for key, value in kwargs.items()
+                if value is not None
+            )
             + ")"
         )
@@ -706,7 +723,7 @@ async def _mutation_mapper(
             or user != self.schd.owner
         ):
             # Logging task messages as commands is overkill.
- LOG.info(f"{log1}\n{log2}") + LOG.info(f"{received_msg}\n{signature_str}") return method(**kwargs) try: @@ -723,14 +740,14 @@ async def _mutation_mapper( except Exception as exc: # NOTE: keep this exception vague to prevent a bad command taking # down the scheduler - LOG.warning(f'{log1}\n{exc.__class__.__name__}: {exc}') + LOG.warning(f'{received_msg}\n{exc.__class__.__name__}: {exc}') if cylc.flow.flags.verbosity > 1: LOG.exception(exc) # log full traceback in debug mode return (False, str(exc)) # Queue the command to the scheduler, with a unique command ID cmd_uuid = str(uuid4()) - LOG.info(f"{log1} ID={cmd_uuid}\n{log2}") + LOG.info(f"{received_msg} ID={cmd_uuid}\n{signature_str}") self.schd.command_queue.put( ( cmd_uuid, @@ -742,7 +759,7 @@ async def _mutation_mapper( def broadcast( self, - mode: str, + mode: 'Enum', cycle_points: Optional[List[str]] = None, namespaces: Optional[List[str]] = None, settings: Optional[List[Dict[str, str]]] = None, @@ -756,15 +773,15 @@ def broadcast( for i, dict_ in enumerate(settings): settings[i] = runtime_schema_to_cfg(dict_) - if mode == 'put_broadcast': + if mode.value == 'put_broadcast': return self.schd.task_events_mgr.broadcast_mgr.put_broadcast( cycle_points, namespaces, settings) - if mode == 'clear_broadcast': + if mode.value == 'clear_broadcast': return self.schd.task_events_mgr.broadcast_mgr.clear_broadcast( point_strings=cycle_points, namespaces=namespaces, cancel_settings=settings) - if mode == 'expire_broadcast': + if mode.value == 'expire_broadcast': return ( self.schd.task_events_mgr.broadcast_mgr.expire_broadcast( cutoff diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py index 801a9b41325..18a5206b3c6 100644 --- a/cylc/flow/network/schema.py +++ b/cylc/flow/network/schema.py @@ -28,7 +28,6 @@ List, Optional, Tuple, - Union, cast, ) @@ -48,6 +47,7 @@ String, ) from graphene.types.generic import GenericScalar +from graphene.types.schema import identity_resolve from graphene.utils.str_converters import to_snake_case from graphql.type.definition import get_named_type @@ -72,31 +72,22 @@ ) from cylc.flow.id import Tokens from cylc.flow.run_modes import ( - WORKFLOW_RUN_MODES, RunMode) -from cylc.flow.task_outputs import SORT_ORDERS -from cylc.flow.task_state import ( - TASK_STATUS_DESC, - TASK_STATUS_EXPIRED, - TASK_STATUS_FAILED, - TASK_STATUS_PREPARING, - TASK_STATUS_RUNNING, - TASK_STATUS_SUBMIT_FAILED, - TASK_STATUS_SUBMITTED, - TASK_STATUS_SUCCEEDED, - TASK_STATUS_WAITING, - TASK_STATUSES_ORDERED, + WORKFLOW_RUN_MODES, + RunMode, ) +from cylc.flow.task_outputs import SORT_ORDERS +from cylc.flow.task_state import TASK_STATUSES_ORDERED from cylc.flow.util import sstrip from cylc.flow.workflow_status import StopMode if TYPE_CHECKING: from enum import Enum - from graphql import ResolveInfo + + from graphql import GraphQLResolveInfo from graphql.type.definition import ( - GraphQLList, GraphQLNamedType, - GraphQLNonNull, + GraphQLType, ) from cylc.flow.network.resolvers import BaseResolvers @@ -288,9 +279,7 @@ class SortArgs(InputObjectType): # Resolvers: -def field_name_from_type( - obj_type: 'Union[GraphQLNamedType, GraphQLList, GraphQLNonNull]' -) -> str: +def field_name_from_type(obj_type: 'GraphQLType') -> str: """Return the field name for given a GraphQL type. If the type is a list or non-null, the base field is extracted. 
@@ -302,19 +291,19 @@ def field_name_from_type( raise ValueError(f"'{named_type.name}' is not a node type") from None -def get_resolvers(info: 'ResolveInfo') -> 'BaseResolvers': +def get_resolvers(info: 'GraphQLResolveInfo') -> 'BaseResolvers': """Return the resolvers from the context.""" return cast('dict', info.context)['resolvers'] def process_resolver_info( - root: Optional[Any], info: 'ResolveInfo', args: Dict[str, Any] + root: Optional[Any], info: 'GraphQLResolveInfo', args: Dict[str, Any] ) -> Tuple[str, Optional[Any]]: """Set and gather info for resolver.""" # Add the subscription id to the resolver context # to know which delta-store to use.""" - if 'backend_sub_id' in info.variable_values: - args['sub_id'] = info.variable_values['backend_sub_id'] + if 'sub_id' in info.context: + args['sub_id'] = info.context['sub_id'] field_name: str = to_snake_case(info.field_name) # root is the parent data object. @@ -336,7 +325,7 @@ def get_native_ids(field_ids): return field_ids -async def get_workflows(root, info: 'ResolveInfo', **args): +async def get_workflows(root, info: 'GraphQLResolveInfo', **args): """Get filtered workflows.""" _, workflow = process_resolver_info(root, info, args) @@ -349,7 +338,7 @@ async def get_workflows(root, info: 'ResolveInfo', **args): return await resolvers.get_workflows(args) -async def get_workflow_by_id(root, info: 'ResolveInfo', **args): +async def get_workflow_by_id(root, info: 'GraphQLResolveInfo', **args): """Return single workflow element.""" _, workflow = process_resolver_info(root, info, args) @@ -362,7 +351,7 @@ async def get_workflow_by_id(root, info: 'ResolveInfo', **args): async def get_nodes_all( - root: Optional[Any], info: 'ResolveInfo', **args + root: Optional[Any], info: 'GraphQLResolveInfo', **args ): """Resolver for returning job, task, family nodes""" @@ -401,7 +390,7 @@ async def get_nodes_all( async def get_nodes_by_ids( - root: Optional[Any], info: 'ResolveInfo', **args + root: Optional[Any], info: 'GraphQLResolveInfo', **args ): """Resolver for returning job, task, family node""" field_name, field_ids = process_resolver_info(root, info, args) @@ -438,7 +427,7 @@ async def get_nodes_by_ids( async def get_node_by_id( - root: Optional[Any], info: 'ResolveInfo', **args + root: Optional[Any], info: 'GraphQLResolveInfo', **args ): """Resolver for returning job, task, family node""" @@ -476,7 +465,7 @@ async def get_node_by_id( ) -async def get_edges_all(root, info: 'ResolveInfo', **args): +async def get_edges_all(root, info: 'GraphQLResolveInfo', **args): """Get all edges from the store filtered by args.""" process_resolver_info(root, info, args) @@ -491,7 +480,7 @@ async def get_edges_all(root, info: 'ResolveInfo', **args): return await resolvers.get_edges_all(args) -async def get_edges_by_ids(root, info: 'ResolveInfo', **args): +async def get_edges_by_ids(root, info: 'GraphQLResolveInfo', **args): """Get all edges from the store by id lookup filtered by args.""" _, field_ids = process_resolver_info(root, info, args) @@ -505,7 +494,7 @@ async def get_edges_by_ids(root, info: 'ResolveInfo', **args): return await resolvers.get_edges_by_ids(args) -async def get_nodes_edges(root, info: 'ResolveInfo', **args): +async def get_nodes_edges(root, info: 'GraphQLResolveInfo', **args): """Resolver for returning job, task, family nodes""" process_resolver_info(root, info, args) @@ -545,7 +534,7 @@ def resolve_state_tasks(root, info, **args): if state in data} -async def resolve_broadcasts(root, info: 'ResolveInfo', **args): +async def 
resolve_broadcasts(root, info: 'GraphQLResolveInfo', **args): """Resolve and parse broadcasts from JSON.""" broadcasts = json.loads( getattr(root, to_snake_case(info.field_name), '{}')) @@ -866,7 +855,7 @@ class Meta: directives = graphene.List(RuntimeSetting, resolver=resolve_json_dump) environment = graphene.List(RuntimeSetting, resolver=resolve_json_dump) outputs = graphene.List(RuntimeSetting, resolver=resolve_json_dump) - run_mode = TaskRunMode(default_value=TaskRunMode.Live.name) + run_mode = TaskRunMode(default_value=TaskRunMode.Live) RUNTIME_FIELD_TO_CFG_MAP = { @@ -1538,7 +1527,7 @@ class Meta: async def mutator( root: Optional[Any], - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', *, command: Optional[str] = None, workflows: Optional[List[str]] = None, @@ -1574,6 +1563,7 @@ async def mutator( if kwargs.get('args', False): kwargs.update(kwargs.get('args', {})) kwargs.pop('args') + resolvers = get_resolvers(info) meta = info.context.get('meta') # type: ignore[union-attr] res = await resolvers.mutator(info, command, w_args, kwargs, meta) @@ -1647,41 +1637,6 @@ class BroadcastCyclePoint(graphene.String): # (broadcast supports either of those two but not cycle point globs) -class TaskStatus(graphene.Enum): - """The status of a task in a workflow.""" - - # NOTE: this is an enumeration purely for the GraphQL schema - # TODO: the task statuses should be formally declared in a Python - # enumeration rendering this class unnecessary - Waiting = TASK_STATUS_WAITING - Expired = TASK_STATUS_EXPIRED - Preparing = TASK_STATUS_PREPARING - SubmitFailed = TASK_STATUS_SUBMIT_FAILED - Submitted = TASK_STATUS_SUBMITTED - Running = TASK_STATUS_RUNNING - Failed = TASK_STATUS_FAILED - Succeeded = TASK_STATUS_SUCCEEDED - - @property - def description(self): - return TASK_STATUS_DESC.get(self.value, '') - - -class TaskState(InputObjectType): - """The state of a task, a combination of status and other fields.""" - - status = TaskStatus() - is_held = Boolean(description=sstrip(''' - If a task is held no new job submissions will be made. - ''')) - is_queued = Boolean(description=sstrip(''' - Task is queued for job submission. - ''')) - is_runahead = Boolean(description=sstrip(''' - Task is runahead limited. - ''')) - - class TaskName(String): """The name a task. @@ -1730,7 +1685,6 @@ class WorkflowStopMode(graphene.Enum): """The mode used to stop a running workflow.""" # NOTE: using a different enum because: - # * Graphene requires special enums. # * We only want to offer a subset of stop modes (REQUEST_* only). Clean = cast('Enum', StopMode.REQUEST_CLEAN.value) @@ -1740,7 +1694,7 @@ class WorkflowStopMode(graphene.Enum): @property def description(self): - return StopMode(self.value).describe() + return StopMode(cast('Enum', self).value).describe() class Flow(String): @@ -1791,9 +1745,7 @@ class Arguments: workflows = graphene.List(WorkflowID, required=True) mode = BroadcastMode( - # use the enum name as the default value - # https://github.com/graphql-python/graphql-core-legacy/issues/166 - default_value=BroadcastMode.Set.name, + default_value=BroadcastMode.Set, description='What type of broadcast is this?', required=True ) @@ -2021,9 +1973,7 @@ class Meta: class Arguments: workflows = graphene.List(WorkflowID, required=True) - mode = WorkflowStopMode( - default_value=WorkflowStopMode.Clean.name - ) + mode = WorkflowStopMode(default_value=WorkflowStopMode.Clean) cycle_point = CyclePoint( description='Stop after the workflow reaches this cycle.' 
) @@ -2348,7 +2298,8 @@ class Mutations(ObjectType): } -def delta_subs(root, info: 'ResolveInfo', **args) -> AsyncGenerator[Any, None]: +def delta_subs( + root, info: 'GraphQLResolveInfo', **args) -> AsyncGenerator[Any, None]: """Generates the root data from the async gen resolver.""" return get_resolvers(info).subscribe_delta(root, info, args) @@ -2360,12 +2311,12 @@ class Meta: the store. ''') workflow = String() - families = graphene.List(String, default_value=[]) - family_proxies = graphene.List(String, default_value=[]) - jobs = graphene.List(String, default_value=[]) - tasks = graphene.List(String, default_value=[]) - task_proxies = graphene.List(String, default_value=[]) - edges = graphene.List(String, default_value=[]) + families = graphene.List(String) + family_proxies = graphene.List(String) + jobs = graphene.List(String) + tasks = graphene.List(String) + task_proxies = graphene.List(String) + edges = graphene.List(String) class Delta(Interface): @@ -2539,6 +2490,28 @@ class Meta: ) +# TODO: Change to use subscribe arg/default. graphql-core has a subscribe field +# for both Meta and Field, graphene at v3.4.3 does not.. As a workaround +# the subscribe function is looked up via the following mapping: +# See https://github.com/cylc/cylc-flow/issues/6688 +SUB_RESOLVER_MAPPING = { + 'deltas': delta_subs, + 'workflows': delta_subs, + 'job': delta_subs, + 'jobs': delta_subs, + 'task': delta_subs, + 'tasks': delta_subs, + 'taskProxy': delta_subs, + 'taskProxies': delta_subs, + 'family': delta_subs, + 'families': delta_subs, + 'familyProxy': delta_subs, + 'familyProxies': delta_subs, + 'edges': delta_subs, + 'nodesEdges': delta_subs, +} + + class Subscriptions(ObjectType): """Defines the subscriptions available in the schema.""" class Meta: @@ -2553,7 +2526,7 @@ class Meta: strip_null=Boolean(default_value=False), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) workflows = graphene.List( Workflow, @@ -2566,7 +2539,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=2.5), - resolver=delta_subs + resolver=identity_resolve ) job = Field( Job, @@ -2577,7 +2550,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) jobs = graphene.List( Job, @@ -2588,7 +2561,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) task = Field( Task, @@ -2599,7 +2572,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) tasks = graphene.List( Task, @@ -2610,7 +2583,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) task_proxy = Field( TaskProxy, @@ -2621,7 +2594,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) task_proxies = graphene.List( TaskProxy, @@ -2632,7 +2605,7 @@ class Meta: 
delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) family = Field( Family, @@ -2643,7 +2616,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) families = graphene.List( Family, @@ -2654,7 +2627,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) family_proxy = Field( FamilyProxy, @@ -2665,7 +2638,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) family_proxies = graphene.List( FamilyProxy, @@ -2676,7 +2649,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) edges = graphene.List( Edge, @@ -2687,7 +2660,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) nodes_edges = Field( NodesEdges, @@ -2698,7 +2671,7 @@ class Meta: delta_type=String(default_value=DELTA_ADDED), initial_burst=Boolean(default_value=True), ignore_interval=Float(default_value=0.0), - resolver=delta_subs + resolver=identity_resolve ) diff --git a/cylc/flow/network/server.py b/cylc/flow/network/server.py index 7ea54bb2f2d..16243812cd3 100644 --- a/cylc/flow/network/server.py +++ b/cylc/flow/network/server.py @@ -29,7 +29,6 @@ Union, ) -from graphql.execution.executors.asyncio import AsyncioExecutor import zmq from zmq.auth.thread import ThreadAuthenticator @@ -42,7 +41,7 @@ from cylc.flow.data_messages_pb2 import PbEntireWorkflow from cylc.flow.data_store_mgr import DELTAS_MAP from cylc.flow.network.graphql import ( - CylcGraphQLBackend, + CylcExecutionContext, IgnoreFieldMiddleware, instantiate_middleware, ) @@ -53,8 +52,6 @@ if TYPE_CHECKING: - from graphql.execution import ExecutionResult - from cylc.flow.network import ResponseDict from cylc.flow.scheduler import Scheduler @@ -265,7 +262,7 @@ def operate(self) -> None: """Orchestrate the receive, send, publish of messages.""" # Note: this cannot be an async method because the response part # of the listener runs the event loop synchronously - # (in graphql AsyncioExecutor) + # (in graphql schema.execute_async) while True: if self.waiting_to_stop: # The self.stop() method is waiting for us to signal that we @@ -402,25 +399,24 @@ def graphql( Returns: object: Execution result, or a list with errors. """ - executed: 'ExecutionResult' = schema.execute( - request_string, - variable_values=variables, - context_value={ - 'resolvers': self.resolvers, - 'meta': meta or {}, - }, - backend=CylcGraphQLBackend(), - middleware=list(instantiate_middleware(self.middleware)), - executor=AsyncioExecutor(), - validate=True, # validate schema (dev only? 
default is True) - return_promise=False, + executed = self.loop.run_until_complete( + schema.execute_async( + request_string, + variable_values=variables, + context_value={ + 'resolvers': self.resolvers, + 'meta': meta or {}, + }, + middleware=list(instantiate_middleware(self.middleware)), + execution_context_class=CylcExecutionContext, + ) ) if executed.errors: for error in executed.errors: LOG.warning(f"GraphQL: {error}") # If there are execution errors, it means there was an unexpected # error, so fail the command. - raise Exception(*executed.errors) + raise Exception(*(error.message for error in executed.errors)) return executed.data # UIServer Data Commands diff --git a/cylc/flow/scripts/remove.py b/cylc/flow/scripts/remove.py index ce7a9c0113c..2c15a12bcd3 100755 --- a/cylc/flow/scripts/remove.py +++ b/cylc/flow/scripts/remove.py @@ -54,6 +54,7 @@ import sys from typing import TYPE_CHECKING +from cylc.flow.flow_mgr import FLOW_ALL from cylc.flow.network.client_factory import get_client from cylc.flow.network.multi import call_multi from cylc.flow.option_parsers import ( @@ -119,7 +120,7 @@ async def run(options: 'Values', workflow_id: str, *tokens_list): tokens.relative_id_with_selectors for tokens in tokens_list ], - 'flow': options.flow, + 'flow': options.flow or [FLOW_ALL], } } diff --git a/cylc/flow/scripts/stop.py b/cylc/flow/scripts/stop.py index 053ca0a4fa0..9cbdb338185 100755 --- a/cylc/flow/scripts/stop.py +++ b/cylc/flow/scripts/stop.py @@ -240,16 +240,24 @@ async def _run( options.max_polls, ) - # mode defaults to 'Clean' - mode = None if stop_task or stop_cycle: - pass + mode: Optional[str] = None elif options.kill: mode = WorkflowStopMode.Kill.name elif options.now > 1: mode = WorkflowStopMode.NowNow.name elif options.now: mode = WorkflowStopMode.Now.name + else: + mode = WorkflowStopMode.Clean.name + + if ( + options.flow_num is not None + and hasattr(options.flow_num, 'isdigit') + and options.flow_num.isdigit() + ): + # flow num gql type is int + options.flow_num = int(options.flow_num) mutation_kwargs = { 'request_string': MUTATION, @@ -260,7 +268,7 @@ async def _run( 'clockTime': options.wall_clock, 'task': stop_task, 'flowNum': options.flow_num - } + }, } ret = await pclient.async_request('graphql', mutation_kwargs) diff --git a/setup.cfg b/setup.cfg index 9ade7832161..89206b64fc2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,7 +65,8 @@ install_requires = ansimarkup>=1.0.0 async-timeout>=3.0.0; python_version < "3.11" colorama>=0.4,<1 - graphene>=2.1,<3 + graphql-core>=3.2,<3.3 + graphene>=3.4.0,<3.5 # Note: can't pin jinja2 any higher than this until we give up on Cylc 7 back-compat jinja2==3.0.* metomi-isodatetime>=1!3.0.0,<1!3.2.0 @@ -78,8 +79,6 @@ install_requires = # NOTE: exclude two urwid versions that were not compatible with Tui urwid>=2.2,!=2.6.2,!=2.6.3,<3 # unpinned transient dependencies used for type checking - rx - promise tomli>=2; python_version < "3.11" [options.packages.find] diff --git a/tests/integration/network/test_graphql.py b/tests/integration/network/test_graphql.py index 6d8f4f63ca3..b285e410811 100644 --- a/tests/integration/network/test_graphql.py +++ b/tests/integration/network/test_graphql.py @@ -16,11 +16,23 @@ """Test the top-level (root) GraphQL queries.""" +from contextlib import suppress import pytest from typing import TYPE_CHECKING +from graphql import parse, MiddlewareManager + +from cylc.flow.data_store_mgr import create_delta_store from cylc.flow.id import Tokens from cylc.flow.network.client import WorkflowRuntimeClient 
+from cylc.flow.network.schema import schema, SUB_RESOLVER_MAPPING +from cylc.flow.network.graphql import ( + CylcExecutionContext, + IgnoreFieldMiddleware, + instantiate_middleware, +) +from cylc.flow.network.graphql_subscribe import subscribe +from cylc.flow.workflow_status import get_workflow_status if TYPE_CHECKING: from cylc.flow.scheduler import Scheduler @@ -52,6 +64,29 @@ def job_config(schd): } +def gather_subscription_args(schd, request_string): + kwargs = { + "variable_values": {}, + "operation_name": None, + "context_value": { + 'op_id': 1, + 'resolvers': schd.server.resolvers, + 'meta': {}, + }, + "subscribe_resolver_map": SUB_RESOLVER_MAPPING, + "middleware": MiddlewareManager( + *list( + instantiate_middleware( + [IgnoreFieldMiddleware] + ) + ) + ), + "execution_context_class": CylcExecutionContext, + } + document = parse(request_string) + return (document, kwargs) + + @pytest.fixture def job_db_row(): return [ @@ -199,14 +234,14 @@ async def test_task_proxies(harness): w_tokens.duplicate( cycle='1', task=namespace, - ).id + ) # NOTE: task "d" is not in the n=1 window yet for namespace in ('a', 'b', 'c') ] ret['taskProxies'].sort(key=lambda x: x['id']) assert ret == { 'taskProxies': [ - {'id': id_} + {'id': id_.id} for id_ in ids ] } @@ -214,12 +249,26 @@ async def test_task_proxies(harness): # query "task" ret = await client.async_request( 'graphql', - {'request_string': 'query { taskProxy(id: "%s") { id } }' % ids[0]} + {'request_string': 'query { taskProxy(id: "%s") { id } }' % ids[0].id} ) assert ret == { - 'taskProxy': {'id': ids[0]} + 'taskProxy': {'id': ids[0].id} } + # query "taskProxies" fragment with null stripping + ret = await client.async_request( + 'graphql', + { + 'request_string': ''' + fragment wf on Workflow { + taskProxies (ids: ["%s"], stripNull: true) { id } + } + query { workflows (ids: ["%s"]) { ...wf } } + ''' % (ids[0].relative_id, ids[0].workflow_id) + } + ) + assert ret == {'workflows': [{'taskProxies': [{'id': ids[0].id}]}]} + async def test_family_proxies(harness): schd, client, w_tokens = harness @@ -348,3 +397,94 @@ async def test_jobs(harness): assert ret == { 'job': {'id': f'{j_id}'} } + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # +# Test the GraphQL subscription infrastructure. 
+# (currently only used at UIS) +@pytest.mark.asyncio(loop_scope="module") +async def test_subscription_basic(harness): + """Test a basic subscription that uses the resolver's sub_resolver code.""" + schd, _, w_tokens = harness + document, kwargs = gather_subscription_args( + schd, + 'subscription { workflows { id } }', + ) + subscription = await subscribe( + schema.graphql_schema, + document, + **kwargs + ) + has_item = False + with suppress(GeneratorExit): + async for response in subscription: + has_item = True + assert response.data['workflows'][0]['id'] == w_tokens.id + await subscription.aclose() + assert has_item + + +@pytest.mark.asyncio(loop_scope="module") +async def test_subscription_deltas(one, start): + """Test the full subscription with null-stripping and delta handling.""" + async with start(one): + document, kwargs = gather_subscription_args( + one, + ''' + subscription { + deltas (stripNull: true) { + id + added { + workflow { + id + host + status + } + } + updated { + workflow { + id + host + status + } + } + } + } + ''', + ) + subscription = await subscribe( + schema.graphql_schema, + document, + **kwargs + ) + aitem = await subscription.__anext__() + assert aitem.data['deltas']['added']['workflow'] == { + 'id': one.id, + 'host': one.host, + 'status': 'running', + } + # Workflow one is paused on start, but this hasn't been processed yet. + await one.update_data_structure() + assert ( + one.data_store_mgr.data[one.id]['workflow'].status + == get_workflow_status(one).value + ) + # Get the all delta, process, then add it to the subscription queue. + btopic, delta, _ = one.data_store_mgr.publish_deltas[-1] + _, sub_queue = next( + iter(one.data_store_mgr.delta_queues[one.id].items()) + ) + sub_queue.put( + ( + one.id, + btopic.decode('utf-8'), + create_delta_store(delta, one.id) + ) + ) + aitem = await subscription.__anext__() + assert aitem.data['deltas']['updated']['workflow'] == { + 'id': one.id, + 'status': get_workflow_status(one).value, + } + with suppress(GeneratorExit): + await subscription.aclose() diff --git a/tests/integration/network/test_server.py b/tests/integration/network/test_server.py index 7df277c66dc..6e3f87b5c94 100644 --- a/tests/integration/network/test_server.py +++ b/tests/integration/network/test_server.py @@ -52,6 +52,23 @@ def test_graphql(myflow): assert myflow.id == data['workflows'][0]['id'] +def test_graphql_error(myflow): + """Test GraphQL endpoint method.""" + request_string = f''' + query {{ + workflows(ids: ["{myflow.id}"]) {{ + id + notafield + alsonotafield + }} + }} + ''' + with pytest.raises(Exception) as excinfo: + myflow.server.graphql(request_string) + assert "Cannot query field 'notafield'" in excinfo + assert "Cannot query field 'alsonotafield'" in excinfo + + def test_pb_data_elements(myflow): """Test Protobuf elements endpoint method.""" element_type = 'workflow' diff --git a/tests/unit/network/test_graphql.py b/tests/unit/network/test_graphql.py index e5be079e289..0b525bec16e 100644 --- a/tests/unit/network/test_graphql.py +++ b/tests/unit/network/test_graphql.py @@ -14,15 +14,18 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
-from typing import Optional, Type - import pytest from pytest import param -from graphql import parse +from graphql import ( + TypeInfo, + TypeInfoVisitor, + parse, + visit +) from cylc.flow.data_messages_pb2 import PbTaskProxy, PbPrerequisite from cylc.flow.network.graphql import ( - AstDocArguments, null_setter, NULL_VALUE, grow_tree + CylcVisitor, null_setter, strip_null, async_next, NULL_VALUE, grow_tree ) from cylc.flow.network.schema import schema @@ -34,8 +37,8 @@ @pytest.mark.parametrize( 'query,' 'variables,' - 'expected_variables,' - 'expected_error', + 'search_arg,' + 'expected_result', [ pytest.param( ''' @@ -49,9 +52,9 @@ 'workflowID': 'cylc|workflow' }, { - 'workflowID': 'cylc|workflow' + 'ids': ['cylc|workflow'], }, - None, + True, id="simple query with correct variables" ), pytest.param( @@ -69,9 +72,9 @@ 'workflowID': 'cylc|workflow' }, { - 'workflowID': 'cylc|workflow' + 'ids': ['cylc|workflow'], }, - None, + True, id="query with a fragment and correct variables" ), pytest.param( @@ -83,48 +86,65 @@ } ''', { - 'workflowId': 'cylc|workflow' + 'workflowID': 'cylc|workflow' }, - None, - ValueError, + { + 'ids': None, + }, + False, id="correct variable definition, but missing variable in " "provided values" + ), + pytest.param( + ''' + query ($workflowID: ID) { + workflows (ids: [$workflowID]) { + id + } + } + ''', + { + 'workflowID': 'cylc|workflow' + }, + { + 'idfsdf': ['cylc|workflow'], + }, + False, + id="correct variable definition, but wrong search argument" ) ] ) def test_query_variables( query: str, variables: dict, - expected_variables: Optional[dict], - expected_error: Optional[Type[Exception]], + search_arg: dict, + expected_result: bool, ): - """Test that query variables are parsed correctly. + """Test that query variables are parsed and found correctly. 
Args: query: a valid GraphQL query (using our schema) variables: map with variable values for the query - expected_variables: expected parsed variables - expected_error: expected error, if any + search_arg: argument and value to search for + expected_result: was the argument and value found """ - def test(): - """Inner function to avoid duplication in if/else""" - document = parse(query) - document_arguments = AstDocArguments( - schema=schema, - document_ast=document, - variable_values=variables - ) - parsed_variables = next( - iter( - document_arguments.operation_defs.values() - ) - )['variables'] - assert expected_variables == parsed_variables - if expected_error is not None: - with pytest.raises(expected_error): - test() - else: - test() + document = parse(query) + type_info = TypeInfo(schema.graphql_schema) + cylc_visitor = CylcVisitor( + type_info, + variables, + search_arg + ) + visit( + document, + TypeInfoVisitor( + type_info, + cylc_visitor + ), + None + ) + + assert expected_result == cylc_visitor.arg_flag @pytest.mark.parametrize( @@ -159,6 +179,42 @@ def test_null_setter(pre_result, expected_result): assert post_result == expected_result +@pytest.mark.parametrize( + 'pre_result,' + 'expected_result', + [ + ( + 'foo', + 'foo' + ), + ( + [NULL_VALUE], + [] + ), + ( + {'nothing': NULL_VALUE}, + {}, + ), + ( + TASK_PROXY_PREREQS.prerequisites, + TASK_PROXY_PREREQS.prerequisites + ), + ] +) +async def test_strip_null(pre_result, expected_result): + """Test the null stripping of different result data/types.""" + # non-async + post_result = async_next(strip_null, pre_result) + assert post_result == expected_result + + async def async_result(result): + return result + + # async + async_post_result = async_next(strip_null, async_result(pre_result)) + assert await async_post_result == expected_result + + @pytest.mark.parametrize( 'expect, tree, path, leaves', [ diff --git a/tests/unit/network/test_schema.py b/tests/unit/network/test_schema.py index 1604cadfb0a..c6cf7ce8d84 100644 --- a/tests/unit/network/test_schema.py +++ b/tests/unit/network/test_schema.py @@ -26,11 +26,12 @@ RUNTIME_FIELD_TO_CFG_MAP, Mutations, Runtime, + WorkflowStopMode, runtime_schema_to_cfg, sort_elements, SortArgs, ) -from cylc.flow.workflow_status import WorkflowStatus +from cylc.flow.workflow_status import StopMode, WorkflowStatus @dataclass @@ -138,3 +139,10 @@ def test_mutations_valid_for(mutation): valid_states = set(match.group(1).split(', ')) assert valid_states assert not valid_states.difference(i.value for i in WorkflowStatus) + + +@pytest.mark.parametrize('wflow_stop_mode', list(WorkflowStopMode)) +def test_stop_mode_enum(wflow_stop_mode): + """Check that WorkflowStopMode is a subset of StopMode.""" + assert StopMode(wflow_stop_mode.value) + assert wflow_stop_mode.description diff --git a/tests/unit/test_links.py b/tests/unit/test_links.py index 9f8a228ad68..2adfb8b284f 100644 --- a/tests/unit/test_links.py +++ b/tests/unit/test_links.py @@ -27,7 +27,8 @@ import re from time import sleep import pytest -import urllib +from urllib import request +from urllib.error import HTTPError EXCLUDE = [ r'*//www.gnu.org/licenses/', @@ -60,13 +61,13 @@ def test_embedded_url(link): to run in parallel """ try: - urllib.request.urlopen(link).getcode() - except urllib.error.HTTPError: + request.urlopen(link).getcode() + except HTTPError: # Sleep and retry to reduce risk of flakiness: sleep(10) try: - urllib.request.urlopen(link).getcode() - except urllib.error.HTTPError as exc: + request.urlopen(link).getcode() + except 
HTTPError as exc: # Allowing 403 - just because a site forbids us doesn't mean the # link is wrong. if exc.code != 403: diff --git a/tox.ini b/tox.ini index 320bd721de3..5f27ab14cd7 100644 --- a/tox.ini +++ b/tox.ini @@ -31,6 +31,7 @@ exclude= .git, __pycache__, .tox, + **graphql_subscribe.py, **data_messages_pb2.py paths = ./cylc/flow