diff --git a/.coveragerc b/.coveragerc index cb452fd6..9ddccdc6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -28,7 +28,6 @@ disable_warnings= omit = venv/* cylc/uiserver/tests/* - cylc/uiserver/websockets/* cylc/uiserver/jupyter*_config.py parallel = True plugins= @@ -61,7 +60,6 @@ ignore_errors = False omit = venv/* cylc/uiserver/tests/* - cylc/uiserver/websockets/* cylc/uiserver/jupyter*_config.py precision=2 show_missing=False diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9b6b0fd1..e4b243e9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: fail-fast: false matrix: os: ['ubuntu-latest'] - python-version: ['3.8', '3.9'] + python-version: ['3', '3.8', '3.9'] include: - os: 'macos-latest' python-version: '3.8' # oldest supported diff --git a/changes.d/672.feat.md b/changes.d/672.feat.md new file mode 100644 index 00000000..f8bdaf82 --- /dev/null +++ b/changes.d/672.feat.md @@ -0,0 +1 @@ +Major version upgrade for graphene/graphql-core dependencies. Removed the graphene-tornado and graphql-ws dependencies which had blocked Python 3.10 adoption. diff --git a/cylc/uiserver/app.py b/cylc/uiserver/app.py index 97413fa1..13761776 100644 --- a/cylc/uiserver/app.py +++ b/cylc/uiserver/app.py @@ -69,11 +69,6 @@ Union, ) -from cylc.flow.network.graphql import ( - CylcGraphQLBackend, - IgnoreFieldMiddleware, -) -from cylc.flow.profiler import Profiler from jupyter_server.extension.application import ExtensionApp from packaging.version import Version from tornado import ioloop @@ -92,6 +87,10 @@ ) from traitlets.config.loader import LazyConfigValue +from cylc.flow.network.graphql import ( + CylcExecutionContext, IgnoreFieldMiddleware +) +from cylc.flow.profiler import Profiler from cylc.uiserver import __file__ as uis_pkg from cylc.uiserver.authorise import ( Authorization, @@ -112,7 +111,7 @@ ) from cylc.uiserver.resolvers import Resolvers from cylc.uiserver.schema import schema -from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer +from cylc.uiserver.graphql.tornado_ws import TornadoSubscriptionServer from cylc.uiserver.workflows_mgr import WorkflowsManager @@ -513,11 +512,11 @@ def initialize_handlers(self): { 'schema': schema, 'resolvers': self.resolvers, - 'backend': CylcGraphQLBackend(), 'middleware': [ AuthorizationMiddleware, IgnoreFieldMiddleware ], + 'execution_context_class': CylcExecutionContext, 'auth': self.authobj, } ), @@ -527,11 +526,11 @@ def initialize_handlers(self): { 'schema': schema, 'resolvers': self.resolvers, - 'backend': CylcGraphQLBackend(), 'middleware': [ AuthorizationMiddleware, IgnoreFieldMiddleware ], + 'execution_context_class': CylcExecutionContext, 'batch': True, 'auth': self.authobj, } @@ -571,11 +570,11 @@ def initialize_handlers(self): def set_sub_server(self): self.subscription_server = TornadoSubscriptionServer( schema, - backend=CylcGraphQLBackend(), middleware=[ IgnoreFieldMiddleware, AuthorizationMiddleware, ], + execution_context_class=CylcExecutionContext, auth=self.authobj, ) diff --git a/cylc/uiserver/authorise.py b/cylc/uiserver/authorise.py index 9333c845..c047b967 100644 --- a/cylc/uiserver/authorise.py +++ b/cylc/uiserver/authorise.py @@ -17,11 +17,11 @@ from functools import lru_cache from getpass import getuser import grp -from inspect import iscoroutinefunction import os from typing import List, Optional, Union, Set, Tuple import graphene +from graphql.pyutils import is_awaitable from jupyter_server.auth import Authorizer from tornado import web @@ -508,10 +508,15 @@ class AuthorizationMiddleware: def resolve(self, next_, root, info, **args): current_user = info.context["current_user"] - # We won't be re-checking auth for return variables - if len(info.path) > 1: + # The resolving starts at the top of the path, so only the first + # entry is guarded, and any subsequent fields do not need to be + # checked. + if len(info.path.as_list()) > 1: return next_(root, info, **args) - op_name = self.get_op_name(info.field_name, info.operation.operation) + op_name = self.get_op_name( + info.field_name, + info.operation.operation.value + ) # It shouldn't get here but worth checking for zero trust if not op_name: self.auth_failed( @@ -527,12 +532,7 @@ def resolve(self, next_, root, info, **args): authorised = False if not authorised: self.auth_failed(current_user, op_name, http_code=403) - if ( - info.operation.operation in Authorization.ASYNC_OPS - or iscoroutinefunction(next_) - ): - return self.async_resolve(next_, root, info, **args) - return next_(root, info, **args) + return self.async_resolve(next_, root, info, **args) def auth_failed( self, @@ -588,7 +588,10 @@ def get_op_name(self, field_name: str, operation: str) -> Optional[str]: async def async_resolve(self, next_, root, info, **args): """Return awaited coroutine""" - return await next_(root, info, **args) + result = next_(root, info, **args) + if is_awaitable(result): + return await result + return result def get_groups(username: str) -> Tuple[List[str], List[str]]: diff --git a/cylc/uiserver/websockets/__init__.py b/cylc/uiserver/graphql/__init__.py similarity index 96% rename from cylc/uiserver/websockets/__init__.py rename to cylc/uiserver/graphql/__init__.py index f5633a1b..09dfd0d9 100644 --- a/cylc/uiserver/websockets/__init__.py +++ b/cylc/uiserver/graphql/__init__.py @@ -13,7 +13,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -"""Websockets and subscriptions related code.""" +"""GraphQL, Websockets and subscriptions related code.""" from typing import ( Awaitable, diff --git a/cylc/uiserver/graphql/tornado.py b/cylc/uiserver/graphql/tornado.py new file mode 100644 index 00000000..40e12c17 --- /dev/null +++ b/cylc/uiserver/graphql/tornado.py @@ -0,0 +1,492 @@ +# The MIT License (MIT) +# +# Copyright (c) 2016-Present Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# ---------------------------------------------------------------------------- +# +# This is a Cylc port of the graphene-django graphene integration: +# https://github.com/graphql-python/graphene-django +# with reference to: +# https://github.com/graphql-python/graphene-tornado +# https://github.com/graphql-python/graphql-server +# +# Excludes GraphiQL + +from asyncio import iscoroutinefunction +import json +import re +import sys +import traceback +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) + +from tornado import web +from tornado.escape import json_encode +from tornado.escape import to_unicode +from tornado.httpclient import HTTPClientError +from tornado.log import app_log +from tornado.web import HTTPError + + +from graphql import ( + DocumentNode, + ExecutionResult, + OperationType, + execute, + get_operation_ast, + parse, + validate_schema, +) +from graphql.error import GraphQLError +from graphql.execution.middleware import MiddlewareManager +from graphql.pyutils import is_awaitable +from graphql.validation import validate + +from cylc.flow.network.graphql import ( + NULL_VALUE, + instantiate_middleware, + strip_null +) + +if TYPE_CHECKING: + from graphene import Schema + from tornado.httputil import HTTPServerRequest + +MUTATION_ERRORS_FLAG = "graphene_mutation_has_errors" +MAX_VALIDATION_ERRORS = None + + +def data_search_action(data, action): + if isinstance(data, dict): + return { + key: data_search_action(val, action) + for key, val in data.items() + } + if isinstance(data, list): + return [ + data_search_action(val, action) + for val in data + ] + return action(data) + + +def get_content_type(request: 'HTTPServerRequest') -> str: + return request.headers.get("Content-Type", "").split(";", 1)[0].lower() + + +def get_accepted_content_types(request: 'HTTPServerRequest') -> list: + def qualify(x): + parts = x.split(";", 1) + if len(parts) == 2: + match = re.match( + r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1]) + if match: + return parts[0].strip(), float(match.group(2)) + return parts[0].strip(), 1 + + raw_content_types = request.headers.get("Accept", "*/*").split(",") + qualified_content_types = map(qualify, raw_content_types) + return [ + x[0] + for x in sorted( + qualified_content_types, key=lambda x: x[1], reverse=True) + ] + + +class ExecutionError(Exception): + def __init__(self, status_code=400, errors=None): + self.status_code = status_code + if errors is None: + self.errors = [] + else: + self.errors = [str(e) for e in errors] + self.message = "\n".join(self.errors) + + +class TornadoGraphQLHandler(web.RequestHandler): + + document: Optional[DocumentNode] + graphql_params: Optional[Tuple[Any, Any, Any, Any]] + middleware: Optional[Union[MiddlewareManager, List[Callable], None]] + parsed_body: Optional[Dict[str, Any]] + + def initialize( + self, + schema: 'Schema', + middleware: Union[MiddlewareManager, List[type], None] = None, + root_value=None, + pretty: bool = False, + batch: bool = False, + subscription_path=None, + execution_context_class=None, + validation_rules=None, + ) -> None: + super(TornadoGraphQLHandler, self).initialize() + self.schema = schema + self.root_value = root_value + self.pretty = pretty + self.batch = batch + self.subscription_path = subscription_path + self.execution_context_class = execution_context_class + self.validation_rules = validation_rules + + self.graphql_params = None + self.parsed_body = None + + if isinstance(middleware, MiddlewareManager): + self.middleware = middleware + elif middleware is not None: + self.middleware = list(instantiate_middleware(middleware)) + else: + self.middleware = None + + def get_context(self): + return self.request + + def get_root_value(self): + return self.root_value + + def get_middleware(self) -> Union[MiddlewareManager, List[Callable], None]: + return self.middleware + + def get_parsed_body(self): + return self.parsed_body + + async def get(self) -> None: + try: + await self.run("get") + except Exception as ex: + self.handle_error(ex) + + async def post(self) -> None: + try: + await self.run("post") + except Exception as ex: + self.handle_error(ex) + + async def run(self, *args, **kwargs): + try: + data = self.parse_body() + + if self.batch: + responses = [ + await self.get_response(entry) + for entry in data + ] + result = "[{}]".format( + ",".join([response[0] for response in responses]) + ) + status_code = ( + responses + and max(responses, key=lambda response: response[1])[1] + or 200 + ) + else: + result, status_code = await self.get_response(data) + + self.set_status(status_code) + self.set_header("Content-Type", "application/json") + self.write(result) + await self.finish() + + except HTTPClientError as e: + response = e.response + response["Content-Type"] = "application/json" + response.content = self.json_encode( + self.request, {"errors": [self.format_error(e)]} + ) + return response + + async def get_response(self, data): + query, variables, operation_name, _id = self.get_graphql_params( + self.request, data + ) + + execution_result = await self.execute_graphql_request( + data, query, variables, operation_name + ) + + status_code = 200 + if execution_result: + response = {} + + if is_awaitable(execution_result) or iscoroutinefunction( + execution_result + ): + execution_result = await execution_result + + if hasattr(execution_result, "get"): + execution_result = execution_result.get() + + if execution_result.errors: + response["errors"] = [ + self.format_error(e) for e in execution_result.errors + ] + + if execution_result.errors and any( + not getattr(e, "path", None) for e in execution_result.errors + ): + status_code = 400 + else: + response["data"] = execution_result.data + + if self.batch: + response["id"] = _id + response["status"] = status_code + try: + result = self.json_encode(response) + except TypeError: + # Catch exceptions in response + errors = [] + + def exc_to_errors(data): + if isinstance(data, Exception): + errors.append({ + 'message': ( + f'{data.value}' + if hasattr(data, 'value') else f'{data}' + ) + }) + return NULL_VALUE + return data + + response = data_search_action( + response, + exc_to_errors + ) + response.setdefault("errors", []).extend(errors) + response = strip_null(response) + + result = self.json_encode(response) + else: + result = None + + return result, status_code + + def json_encode(self, d, pretty=False): + if (self.pretty or pretty) or self.get_query_argument("pretty", False): + return json.dumps( + d, sort_keys=True, indent=2, separators=(",", ": ")) + + return json.dumps(d, separators=(",", ":")) + + def parse_body(self): + content_type = get_content_type(self.request) + + if content_type == "application/graphql": + self.parsed_body = {"query": to_unicode(self.request.body)} + return self.parsed_body + + elif content_type == "application/json": + try: + body = self.request.body + except Exception as e: + raise ExecutionError(400, e) + + try: + request_json = json.loads(body) + if self.batch: + if not isinstance(request_json, list): + raise AssertionError( + "Batch requests should receive a list" + ", but received {}." + ).format(repr(request_json)) + if len(request_json <= 0): + raise AssertionError( + "Received an empty list in the batch request." + ) + else: + if not isinstance(request_json, dict): + raise AssertionError( + "The received data is not a valid JSON query." + ) + self.parsed_body = request_json + return self.parsed_body + except AssertionError as e: + raise HTTPError(status_code=400, log_message=str(e)) + except (TypeError, ValueError): + raise HTTPError( + status_code=400, log_message="POST body sent invalid JSON." + ) + + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: + self.parsed_body = self.request.query_arguments + return self.parsed_body + + self.parsed_body = {} + return self.parsed_body + + async def execute_graphql_request( + self, data, query, variables, operation_name + ): + if not query: + raise HTTPError( + status_code=400, log_message="Must provide query string." + ) + + schema = self.schema.graphql_schema + + schema_validation_errors = validate_schema(schema) + if schema_validation_errors: + return ExecutionResult(data=None, errors=schema_validation_errors) + + try: + self.document = parse(query) + except Exception as e: + return ExecutionResult(errors=[e]) + + operation_ast = get_operation_ast(self.document, operation_name) + + if ( + self.request.method.lower() == "get" + and operation_ast is not None + and operation_ast.operation != OperationType.QUERY + ): + raise HTTPError( + status_code=405, + log_message=( + f'Can only perform a {operation_ast.operation.value} ' + 'operation from a POST request.' + ), + ) + + validation_errors = validate( + schema, + self.document, + self.validation_rules, + MAX_VALIDATION_ERRORS, + ) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + + try: + execute_options = { + "root_value": self.get_root_value(), + "context_value": self.get_context(), + "variable_values": variables, + "operation_name": operation_name, + "middleware": self.get_middleware(), + } + if self.execution_context_class: + execute_options[ + "execution_context_class" + ] = self.execution_context_class + + result = await self.execute( + schema, + self.document, + **execute_options + ) + + return result + except Exception as e: + return ExecutionResult(errors=[e]) + + async def execute(self, *args, **kwargs): + return execute(*args, **kwargs) + + def request_wants_html(self): + accepted = get_accepted_content_types(self.request) + accepted_length = len(accepted) + # the list will be ordered in preferred first - so we have to make + # sure the most preferred gets the highest number + html_priority = ( + accepted_length - accepted.index("text/html") + if "text/html" in accepted + else 0 + ) + json_priority = ( + accepted_length - accepted.index("application/json") + if "application/json" in accepted + else 0 + ) + + return html_priority > json_priority + + def get_graphql_params(self, request, data): + if self.graphql_params: + return self.graphql_params + + single_args = {} + for key in request.arguments.keys(): + single_args[key] = self.decode_argument( + request.arguments.get(key)[0]) + + query = single_args.get("query") or data.get("query") + variables = single_args.get("variables") or data.get("variables") + _id = single_args.get("id") or data.get("id") + + if variables and isinstance(variables, str): + try: + variables = json.loads(variables) + except Exception: + raise HTTPError( + status_code=400, + log_message="Variables are invalid JSON." + ) + + operation_name = ( + single_args.get("operationName") or data.get("operationName") + ) + if operation_name == "null": + operation_name = None + + self.graphql_params = query, variables, operation_name, _id + return self.graphql_params + + def handle_error(self, ex: Exception) -> None: + if not isinstance(ex, (web.HTTPError, ExecutionError, GraphQLError)): + tb = "".join(traceback.format_exception(*sys.exc_info())) + app_log.error("Error: {0} {1}".format(ex, tb)) + self.set_status(self.error_status(ex)) + error_json = json_encode({"errors": self.error_format(ex)}) + app_log.debug("error_json: %s", error_json) + self.write(error_json) + + @staticmethod + def error_status(exception: Exception) -> int: + if isinstance(exception, web.HTTPError): + return exception.status_code + elif isinstance(exception, (ExecutionError, GraphQLError)): + return 400 + else: + return 500 + + @staticmethod + def error_format(exception: Exception) -> List[Dict[str, Any]]: + if isinstance(exception, ExecutionError): + return [{"message": e} for e in exception.errors] + elif isinstance(exception, GraphQLError): + return [{"message": exception.formatted["message"]}] + elif isinstance(exception, web.HTTPError): + return [{"message": exception.log_message}] + else: + return [{"message": "Unknown server error"}] + + @staticmethod + def format_error(error): + if isinstance(error, GraphQLError): + return error.formatted + + return {"message": str(error)} diff --git a/cylc/uiserver/graphql/tornado_ws.py b/cylc/uiserver/graphql/tornado_ws.py new file mode 100644 index 00000000..45b2514c --- /dev/null +++ b/cylc/uiserver/graphql/tornado_ws.py @@ -0,0 +1,398 @@ +# The MIT License (MIT) +# +# Copyright (c) 2016-Present Syrus Akbary +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# ---------------------------------------------------------------------------- +# +# This code was derived from coded in the graphql-ws project, and code that +# was offered to the graphql-ws project by members of the Cylc +# development team but not merged. +# +# * https://github.com/graphql-python/graphql-ws +# * https://github.com/graphql-python/graphql-ws/pull/25/files +# +# It has been evolved to suit and ported to graphql-core v3. + +import asyncio +from asyncio.queues import QueueEmpty +from contextlib import suppress +from inspect import isawaitable +import json +from weakref import WeakSet + +from tornado.websocket import WebSocketClosedError +from graphql import ( + parse, + validate, + ExecutionResult, + GraphQLError, + MiddlewareManager, +) +from graphql.pyutils import is_awaitable + +from cylc.flow.network.graphql import instantiate_middleware +from cylc.flow.network.graphql_subscribe import subscribe + +from cylc.uiserver.authorise import AuthorizationMiddleware +from cylc.uiserver.schema import SUB_RESOLVER_MAPPING + + +NO_MSG_DELAY = 1.0 + +GRAPHQL_WS = "graphql-ws" +WS_PROTOCOL = GRAPHQL_WS +GQL_CONNECTION_INIT = "connection_init" # Client -> Server +GQL_CONNECTION_ACK = "connection_ack" # Server -> Client +GQL_CONNECTION_ERROR = "connection_error" # Server -> Client +GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server +GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client +GQL_START = "start" # Client -> Server +GQL_DATA = "data" # Server -> Client +GQL_ERROR = "error" # Server -> Client +GQL_COMPLETE = "complete" # Server -> Client +GQL_STOP = "stop" # Client -> Server + + +class ConnectionClosedException(Exception): + pass + + +class TornadoConnectionContext: + + def __init__(self, ws, request_context=None): + self.ws = ws + self.operations = {} + self.request_context = request_context + self.pending_tasks = WeakSet() + + def has_operation(self, op_id): + return op_id in self.operations + + def register_operation(self, op_id, async_iterator): + self.operations[op_id] = async_iterator + + def get_operation(self, op_id): + return self.operations[op_id] + + def remove_operation(self, op_id): + try: + return self.operations.pop(op_id) + except KeyError: + return + + async def receive(self): + try: + return self.ws.recv_nowait() + except WebSocketClosedError: + raise ConnectionClosedException() + + async def send(self, data): + if self.closed: + return + await self.ws.write_message(data) + + @property + def closed(self): + return self.ws.close_code is not None + + async def close(self, code): + await self.ws.close(code) + + def remember_task(self, task): + self.pending_tasks.add(task) + # Clear completed tasks + self.pending_tasks -= WeakSet( + task for task in self.pending_tasks if task.done() + ) + + async def unsubscribe(self, op_id): + async_iterator = self._unsubscribe(op_id) + if ( + getattr(async_iterator, "future", None) + and async_iterator.future.cancel() + ): + await async_iterator.future + + def _unsubscribe(self, op_id): + async_iterator = self.remove_operation(op_id) + if hasattr(async_iterator, "dispose"): + async_iterator.dispose() + return async_iterator + + async def unsubscribe_all(self): + awaitables = [ + self.unsubscribe(op_id) + for op_id in list(self.operations) + ] + for task in self.pending_tasks: + task.cancel() + awaitables.append(task) + if awaitables: + with suppress(asyncio.CancelledError): + await asyncio.gather(*awaitables) + + +class TornadoSubscriptionServer: + + def __init__( + self, + schema, + keep_alive=True, + loop=None, + middleware=None, + execution_context_class=None, + auth=None + ): + self.schema = schema + self.loop = loop + self.middleware = middleware + self.execution_context_class = execution_context_class + self.auth = auth + + async def execute(self, params): + # Parse query to document + try: + document = parse(params['query']) + except GraphQLError as error: + return ExecutionResult(data=None, errors=[error]) + + # Validate document against schema + validation_errors = validate(self.schema.graphql_schema, document) + if validation_errors: + return ExecutionResult(data=None, errors=validation_errors) + + # execute subscription + return await subscribe( + self.schema.graphql_schema, + document, + **params['kwargs'] + ) + + def process_message(self, connection_context, parsed_message): + task = asyncio.ensure_future( + self._process_message(connection_context, parsed_message), + loop=self.loop + ) + connection_context.remember_task(task) + return task + + async def _process_message(self, connection_context, parsed_message): + op_id = parsed_message.get("id") + op_type = parsed_message.get("type") + payload = parsed_message.get("payload") + + if op_type == GQL_CONNECTION_INIT: + return await self.on_connection_init( + connection_context, op_id, payload + ) + + elif op_type == GQL_CONNECTION_TERMINATE: + return self.on_connection_terminate(connection_context, op_id) + + elif op_type == GQL_START: + if not isinstance(payload, dict): + raise AssertionError("The payload must be a dict") + params = self.get_graphql_params(connection_context, payload) + return await self.on_start(connection_context, op_id, params) + + elif op_type == GQL_STOP: + return await self.on_stop(connection_context, op_id) + + else: + return await self.send_error( + connection_context, + op_id, + Exception("Invalid message type: {}.".format(op_type)), + ) + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message( + connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error( + connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_connect(self, connection_context, payload): + pass + + def on_connection_terminate(self, connection_context, op_id): + return connection_context.close(1011) + + def get_graphql_params(self, connection_context, payload): + # Create a new context object for each subscription, + # which allows it to carry a unique subscription id. + params = { + "variable_values": payload.get("variables"), + "operation_name": payload.get("operationName"), + "context_value": dict( + payload.get("context", connection_context.request_context) + ), + "subscribe_resolver_map": SUB_RESOLVER_MAPPING, + } + # If middleware get instantiated here (optional), they will + # be local/private to each subscription. + if self.middleware is not None: + middleware = list( + instantiate_middleware(self.middleware) + ) + else: + middleware = self.middleware + for mw in self.middleware: + if mw == AuthorizationMiddleware: + mw.auth = self.auth + return { + 'query': payload.get("query"), + 'kwargs': dict( + params, + middleware=MiddlewareManager( + *middleware, + ), + execution_context_class=self.execution_context_class, + ) + } + + async def on_open(self, connection_context): + pass + + async def on_stop(self, connection_context, op_id): + return await connection_context.unsubscribe(op_id) + + async def on_close(self, connection_context): + return await connection_context.unsubscribe_all() + + async def handle(self, ws, request_context=None): + await asyncio.shield(self._handle(ws, request_context)) + + async def _handle(self, ws, request_context=None): + connection_context = TornadoConnectionContext(ws, request_context) + await self.on_open(connection_context) + while True: + message = None + try: + if connection_context.closed: + raise ConnectionClosedException() + message = await connection_context.receive() + except QueueEmpty: + pass + except ConnectionClosedException: + break + if message: + await self.on_message(connection_context, message) + else: + await asyncio.sleep(NO_MSG_DELAY) + + await self.on_close(connection_context) + + async def on_start(self, connection_context, op_id, params): + # Attempt to unsubscribe first in case we already have a subscription + # with this id. + await connection_context.unsubscribe(op_id) + + params['kwargs']['root_value'] = op_id + execution_result = await self.execute(params) + iterator = None + try: + if isawaitable(execution_result): + execution_result = await execution_result + if not hasattr(execution_result, '__aiter__'): + await self.send_execution_result( + connection_context, op_id, execution_result) + else: + iterator = execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result( + connection_context, op_id, single_result) + except (GeneratorExit, asyncio.CancelledError): + raise + except Exception as e: + await self.send_error(connection_context, op_id, e) + finally: + if iterator: + await iterator.aclose() + await self.send_message(connection_context, op_id, GQL_COMPLETE) + await connection_context.unsubscribe(op_id) + await self.on_operation_complete(connection_context, op_id) + + async def send_message( + self, connection_context, op_id=None, op_type=None, payload=None + ): + message = self.build_message(op_id, op_type, payload) + return await connection_context.send(message) + + def build_message(self, _id, op_type, payload): + message = {} + if _id is not None: + message["id"] = _id + if op_type is not None: + message["type"] = op_type + if payload is not None: + message["payload"] = payload + if not message: + raise AssertionError("You need to send at least one thing") + return message + + async def send_execution_result( + self, connection_context, op_id, execution_result): + # Resolve any pending promises + if is_awaitable(execution_result.data): + await execution_result.data + if execution_result.data and 'logs' not in execution_result.data: + request_context = connection_context.request_context + await request_context['resolvers'].flow_delta_processed( + request_context, op_id) + + result = execution_result.formatted + return await self.send_message( + connection_context, op_id, GQL_DATA, result + ) + + async def on_operation_complete(self, connection_context, op_id): + # remove the subscription from the sub_statuses dict + with suppress(KeyError): + connection_context.request_context['sub_statuses'].pop(op_id) + + async def send_error( + self, connection_context, op_id, error, error_type=None + ): + if error_type is None: + error_type = GQL_ERROR + + if error_type not in {GQL_CONNECTION_ERROR, GQL_ERROR}: + raise AssertionError( + "error_type should be one of the allowed error messages" + " GQL_CONNECTION_ERROR or GQL_ERROR" + ) + + error_payload = {"message": str(error)} + + return await self.send_message( + connection_context, op_id, error_type, error_payload) + + async def on_message(self, connection_context, message): + try: + if not isinstance(message, dict): + parsed_message = json.loads(message) + if not isinstance(parsed_message, dict): + raise AssertionError("Payload must be an object.") + else: + parsed_message = message + except Exception as e: + return await self.send_error(connection_context, None, e) + + return self.process_message(connection_context, parsed_message) diff --git a/cylc/uiserver/handlers.py b/cylc/uiserver/handlers.py index 5809f48b..4238a198 100644 --- a/cylc/uiserver/handlers.py +++ b/cylc/uiserver/handlers.py @@ -19,16 +19,9 @@ import json import os import re -from typing import ( - TYPE_CHECKING, - Callable, - Dict, -) +from typing import TYPE_CHECKING, Callable, Dict, Awaitable, Optional from cylc.flow import __version__ as cylc_flow_version -from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler -from graphql import get_default_backend -from graphql_ws.constants import GRAPHQL_WS from jupyter_server.base.handlers import JupyterHandler from tornado import ( web, @@ -37,20 +30,17 @@ from tornado.ioloop import IOLoop from cylc.uiserver import __version__ -from cylc.uiserver.authorise import ( - Authorization, - AuthorizationMiddleware, -) +from cylc.uiserver.authorise import Authorization, AuthorizationMiddleware +from cylc.uiserver.graphql import authenticated as websockets_authenticated +from cylc.uiserver.graphql.tornado import TornadoGraphQLHandler +from cylc.uiserver.graphql.tornado_ws import GRAPHQL_WS from cylc.uiserver.utils import is_bearer_token_authenticated -from cylc.uiserver.websockets import authenticated as websockets_authenticated if TYPE_CHECKING: - from graphql.execution import ExecutionResult - from jupyter_server.auth.identity import User as JPSUser - from cylc.uiserver.resolvers import Resolvers - from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer + from cylc.uiserver.graphql.tornado_ws import TornadoSubscriptionServer + from jupyter_server.auth.identity import User as JPSUser ME = getpass.getuser() @@ -164,7 +154,7 @@ class CylcAppHandler(JupyterHandler): def initialize(self, auth): self.auth = auth - super().initialize() + JupyterHandler.initialize(self) @property def hub_users(self): @@ -322,32 +312,41 @@ class UIServerGraphQLHandler(CylcAppHandler, TornadoGraphQLHandler): def set_default_headers(self) -> None: self.set_header('Server', '') - def initialize(self, schema=None, executor=None, middleware=None, - root_value=None, graphiql=False, pretty=False, - batch=False, backend=None, auth=None, **kwargs): - super(TornadoGraphQLHandler, self).initialize() - self.auth = auth - self.schema = schema + def initialize( + self, + schema, + middleware=None, + root_value=None, + pretty=False, + batch=False, + execution_context_class=None, + validation_rules=None, + auth=None, + **kwargs, + ): + TornadoGraphQLHandler.initialize( + self, + schema, + middleware=middleware, + root_value=root_value, + pretty=pretty, + batch=batch, + execution_context_class=execution_context_class, + validation_rules=validation_rules, + ) + CylcAppHandler.initialize(self, auth) - if middleware is not None: - self.middleware = list(self.instantiate_middleware(middleware)) # Make authorization info available to auth middleware for mw in self.middleware: if isinstance(mw, AuthorizationMiddleware): mw.auth = self.auth - self.executor = executor - self.root_value = root_value - self.pretty = pretty - self.graphiql = graphiql - self.batch = batch - self.backend = backend or get_default_backend() + # Set extra attributes for key, value in kwargs.items(): if hasattr(self, key): setattr(self, key, value) - @property - def context(self): + def get_context(self): """The GraphQL context passed to resolvers (incl middleware).""" return { 'graphql_params': self.graphql_params, @@ -357,15 +356,10 @@ def context(self): } @web.authenticated # type: ignore[arg-type] - async def execute(self, *args, **kwargs) -> 'ExecutionResult': - # Use own backend, and TornadoGraphQLHandler already does validation. - return await self.schema.execute( - *args, - backend=self.backend, - variable_values=kwargs.get('variables'), - validate=False, - **kwargs, - ) + async def execute( + self, *args, **kwargs + ) -> Optional[Awaitable[None]]: + return await TornadoGraphQLHandler.execute(self, *args, **kwargs) @web.authenticated async def run(self, *args, **kwargs): diff --git a/cylc/uiserver/resolvers.py b/cylc/uiserver/resolvers.py index 056d4eaf..f322035b 100644 --- a/cylc/uiserver/resolvers.py +++ b/cylc/uiserver/resolvers.py @@ -16,6 +16,7 @@ """GraphQL resolvers for use in data accessing and mutation of workflows.""" import asyncio +from enum import Enum from contextlib import suppress from copy import deepcopy import errno @@ -42,17 +43,14 @@ Union, ) +from graphql.language import print_ast +import psutil + from cylc.flow.data_store_mgr import WORKFLOW from cylc.flow.exceptions import CylcError from cylc.flow.id import Tokens from cylc.flow.network.resolvers import BaseResolvers -from cylc.flow.scripts.clean import ( - CleanOptions, - run, -) -from graphql.language.base import print_ast -import psutil - +from cylc.flow.scripts.clean import CleanOptions, run if TYPE_CHECKING: from concurrent.futures import Executor @@ -61,7 +59,7 @@ from cylc.flow.data_store_mgr import DataStoreMgr from cylc.flow.option_parsers import Options - from graphql import ResolveInfo + from graphql import GraphQLResolveInfo from cylc.uiserver.app import CylcUIServer from cylc.uiserver.workflows_mgr import WorkflowsManager @@ -142,6 +140,8 @@ def _build_cmd(cmd: List, args: Dict) -> List: if isinstance(value, int) and not isinstance(value, bool): # Any integer items need converting to strings: value = str(value) + elif isinstance(value, Enum): + value = value.value value = [value] for item in value: cmd.append(key) @@ -537,7 +537,7 @@ def __init__( # Mutations async def mutator( self, - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', command: str, w_args: Dict[str, Any], _kwargs: Dict[str, Any], @@ -562,7 +562,8 @@ async def mutator( # Create a modified request string, # containing only the current mutation/field. operation_ast = deepcopy(info.operation) - operation_ast.selection_set.selections = info.field_asts + operation_ast.selection_set.selections = tuple( + n for n in info.field_nodes) graphql_args = { 'request_string': print_ast(operation_ast), @@ -574,12 +575,19 @@ async def mutator( async def service( self, - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', command: str, workflows: Iterable['Tokens'], kwargs: Dict[str, Any], ) -> List[Union[bool, str]]: + # GraphQL v3 includes all variables that are set, even if set to null. + kwargs = { + k: v + for k, v in kwargs.items() + if v is not None + } + if command == 'clean': # noqa: SIM116 return await Services.clean( workflows, @@ -605,7 +613,7 @@ async def service( async def subscription_service( self, - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', _command: str, ids: List[Tokens], file=None @@ -652,7 +660,7 @@ def kill_process_tree( async def list_log_files( root: Optional[Any], - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', id: str, # noqa: required to match schema arg name ): tokens = Tokens(id) @@ -665,7 +673,7 @@ async def list_log_files( async def stream_log( root: Optional[Any], - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', *, command='cat_log', id: str, # noqa: required to match schema arg name diff --git a/cylc/uiserver/schema.py b/cylc/uiserver/schema.py index f1c73378..ef8b3b11 100644 --- a/cylc/uiserver/schema.py +++ b/cylc/uiserver/schema.py @@ -30,6 +30,10 @@ Tuple, ) +import graphene +from graphene.types.generic import GenericScalar +from graphene.types.schema import identity_resolve + from cylc.flow.data_store_mgr import ( JOBS, TASKS, @@ -38,6 +42,7 @@ from cylc.flow.network.schema import ( NODE_MAP, STRIP_NULL_DEFAULT, + SUB_RESOLVER_MAPPING, CyclePoint, GenericResponse, Job, @@ -47,7 +52,7 @@ Subscriptions, Task, WorkflowID, - WorkflowRunMode as RunMode, + WorkflowRunMode, _mut_field, get_nodes_all, process_resolver_info, @@ -64,8 +69,6 @@ ) from cylc.flow.util import sstrip from cylc.flow.workflow_files import WorkflowFiles -import graphene -from graphene.types.generic import GenericScalar from cylc.uiserver.resolvers import ( Resolvers, @@ -75,12 +78,12 @@ if TYPE_CHECKING: - from graphql import ResolveInfo + from graphql import GraphQLResolveInfo async def mutator( root: Optional[Any], - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', *, command: str, workflows: Optional[List[str]] = None, @@ -98,7 +101,11 @@ async def mutator( resolvers: 'Resolvers' = ( info.context.get('resolvers') # type: ignore[union-attr] ) - res = await resolvers.service(info, command, parsed_workflows, kwargs) + try: + res = await resolvers.service(info, command, parsed_workflows, kwargs) + except Exception as exc: + resolvers.log.exception(exc) + raise return GenericResponse(result=res) @@ -167,9 +174,7 @@ class Arguments: Hold all tasks after this cycle point. ''') ) - mode = RunMode( - default_value=RunMode.Live.name - ) + mode = WorkflowRunMode(default_value=WorkflowRunMode.Live) host = graphene.String( description=sstrip(''' Specify the host on which to start-up the workflow. If not @@ -845,6 +850,16 @@ class LogFiles(graphene.ObjectType): ) +# TODO: Change to use subscribe arg/default. +# See https://github.com/cylc/cylc-flow/issues/6688 +# graphql-core has a subscribe field for both Meta and Field, +# graphene at v3.4.3 does not. As a workaround +# the subscribe function is looked up via the following mapping: +SUB_RESOLVER_MAPPING.update({ + 'logs': stream_log, # type: ignore +}) + + class UISSubscriptions(Subscriptions): # Example graphiql workflow log subscription: # subscription { @@ -872,7 +887,7 @@ class Logs(graphene.ObjectType): required=False, description='File name of job log to fetch, e.g. job.out' ), - resolver=stream_log + resolver=identity_resolve ) diff --git a/cylc/uiserver/tests/test_graphql.py b/cylc/uiserver/tests/test_graphql.py index d0b26d0b..7ff0c7d9 100644 --- a/cylc/uiserver/tests/test_graphql.py +++ b/cylc/uiserver/tests/test_graphql.py @@ -161,3 +161,18 @@ async def _log(*args, **kwargs): } } ''').strip() + + # issue clean mutation + response = await gql_query( + *('cylc', 'graphql'), + query=''' + mutation { + clean(workflows: ["%s"]){ + result + } + } + ''' % ( + Tokens(user='me', workflow='foo').id, + ), + ) + assert response.code == 200 diff --git a/cylc/uiserver/tests/test_handlers.py b/cylc/uiserver/tests/test_handlers.py index e07d6066..614550d8 100644 --- a/cylc/uiserver/tests/test_handlers.py +++ b/cylc/uiserver/tests/test_handlers.py @@ -20,11 +20,11 @@ from unittest.mock import MagicMock import pytest -from graphql_ws.constants import GRAPHQL_WS from tornado.httputil import HTTPServerRequest from tornado.testing import AsyncHTTPTestCase, get_async_test_timeout from tornado.web import Application +from cylc.uiserver.graphql.tornado_ws import GRAPHQL_WS from cylc.uiserver.handlers import SubscriptionHandler diff --git a/cylc/uiserver/tests/test_resolvers.py b/cylc/uiserver/tests/test_resolvers.py index 36a41901..026ec1bd 100644 --- a/cylc/uiserver/tests/test_resolvers.py +++ b/cylc/uiserver/tests/test_resolvers.py @@ -16,7 +16,6 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Tuple -from async_timeout import timeout import logging import os import pytest @@ -24,6 +23,12 @@ from subprocess import Popen, TimeoutExpired from types import SimpleNamespace +import sys +if sys.version_info >= (3, 11): + from asyncio import timeout +else: + from async_timeout import timeout + from cylc.flow import CYLC_LOG from cylc.flow.exceptions import CylcError from cylc.flow.id import Tokens diff --git a/cylc/uiserver/tests/test_schema.py b/cylc/uiserver/tests/test_schema.py index 4c03be36..d1e393c7 100644 --- a/cylc/uiserver/tests/test_schema.py +++ b/cylc/uiserver/tests/test_schema.py @@ -14,9 +14,10 @@ # along with this program. If not, see . import importlib +import pytest import cylc.uiserver.schema -from cylc.uiserver.schema import NODE_MAP as UIS_NODE_MAP +from cylc.uiserver.schema import NODE_MAP as UIS_NODE_MAP, mutator def test_node_map(): @@ -30,3 +31,19 @@ def test_node_map(): _class = getattr(cylc.uiserver.schema, type_name) # It is not straightforward to check that _class is a Graphene class assert 'graphene' in type(_class).__module__ + + +async def test_mutator(cylc_uis): + """Test exception and arg variants.""" + class Info: + context = {'resolvers': cylc_uis.resolvers, 'id': 1} + + with pytest.raises(Exception) as exc: + await mutator( + None, + Info(), + command='NotACommand', + workflows=None, + args={'This': 'That'} + ) + assert 'NotImplementedError' in f'{exc}' diff --git a/cylc/uiserver/websockets/resolve.py b/cylc/uiserver/websockets/resolve.py deleted file mode 100644 index 71bb8f23..00000000 --- a/cylc/uiserver/websockets/resolve.py +++ /dev/null @@ -1,67 +0,0 @@ -# MIT License -# -# Copyright (c) 2017, Syrus Akbary -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -""" -This file contains an implementation of "resolve" derived from the one -found in the graphql-ws library with the above license. - -This is temporary code until the change makes its way upstream. -""" - -# NOTE: transient dependency from graphql-ws purposefully not -# reflected in cylc-uiserver dependencies -from promise import Promise - -from graphql_ws.base_async import is_awaitable - - -async def resolve( - data, - _container=None, - _key=None, -): - """ - Wait on any awaitable children of a data element and resolve - any Promises. - """ - stack = [(data, _container, _key)] - - while stack: - _data, _container, _key = stack.pop() - - if is_awaitable(_data): - _data = await _data - if isinstance(_data, Promise): - _data = _data.value - if _container is not None: - _container[_key] = _data - if isinstance(_data, dict): - items = _data.items() - elif isinstance(_data, list): - items = enumerate(_data) - else: - items = None - if items is not None: - stack.extend([ - (child, _data, key) - for key, child in items - ]) diff --git a/cylc/uiserver/websockets/tornado.py b/cylc/uiserver/websockets/tornado.py deleted file mode 100644 index aaf2e285..00000000 --- a/cylc/uiserver/websockets/tornado.py +++ /dev/null @@ -1,178 +0,0 @@ -# This file is a temporary solution for subscriptions with graphql_ws and -# Tornado, from the following pending PR to graphql-ws: -# https://github.com/graphql-python/graphql-ws/pull/25/files -# The file was copied from this revision: -# https://github.com/graphql-python/graphql-ws/blob/cf560b9a5d18d4a3908dc2cfe2199766cc988fef/graphql_ws/tornado.py - -from contextlib import suppress -import getpass -from inspect import isawaitable, isclass -import socket - -from asyncio import create_task, gather, wait, shield, sleep -from asyncio.queues import QueueEmpty -from tornado.websocket import WebSocketClosedError -from graphql.execution.middleware import MiddlewareManager -from graphql_ws.base import ConnectionClosedException, BaseSubscriptionServer -from graphql_ws.base_async import ( - BaseAsyncConnectionContext, - BaseAsyncSubscriptionServer -) -from graphql_ws.observable_aiter import setup_observable_extension -from graphql_ws.constants import ( - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_COMPLETE -) - -from typing import Union, Awaitable, Any, List, Tuple, Dict, Optional - -from cylc.uiserver.authorise import AuthorizationMiddleware -from cylc.uiserver.websockets.resolve import resolve - - -setup_observable_extension() - -NO_MSG_DELAY = 1.0 - - -class TornadoConnectionContext(BaseAsyncConnectionContext): - async def receive(self): - try: - return self.ws.recv_nowait() - except WebSocketClosedError: - raise ConnectionClosedException() - - async def send(self, data): - if self.closed: - return - await self.ws.write_message(data) - - @property - def closed(self): - return self.ws.close_code is not None - - async def close(self, code): - await self.ws.close(code) - - -class TornadoSubscriptionServer(BaseAsyncSubscriptionServer): - def __init__( - self, schema, - keep_alive=True, - loop=None, - backend=None, - middleware=None, - auth=None - ): - self.loop = loop - self.backend = backend or None - self.middleware = middleware - self.auth = auth - super().__init__(schema, keep_alive) - - @staticmethod - def instantiate_middleware(middlewares): - for middleware in middlewares: - if isclass(middleware): - yield middleware() - continue - yield middleware - - def get_graphql_params(self, *args, **kwargs): - params = super(TornadoSubscriptionServer, - self).get_graphql_params(*args, **kwargs) - # If middleware get instantiated here (optional), they will - # be local/private to each subscription. - if self.middleware is not None: - middleware = list( - self.instantiate_middleware(self.middleware) - ) - else: - middleware = self.middleware - for mw in self.middleware: - if mw == AuthorizationMiddleware: - mw.auth = self.auth - return dict( - params, - return_promise=True, - backend=self.backend, - middleware=MiddlewareManager( - *middleware, - wrap_in_promise=False - ), - ) - - async def _handle(self, ws, request_context=None): - connection_context = TornadoConnectionContext(ws, request_context) - await self.on_open(connection_context) - while True: - message = None - try: - if connection_context.closed: - raise ConnectionClosedException() - message = await connection_context.receive() - except QueueEmpty: - pass - except ConnectionClosedException: - break - if message: - self.on_message(connection_context, message) - else: - await sleep(NO_MSG_DELAY) - - await self.on_close(connection_context) - - async def handle(self, ws, request_context=None): - await shield(self._handle(ws, request_context)) - - async def on_start(self, connection_context, op_id, params): - # Attempt to unsubscribe first in case we already have a subscription - # with this id. - await connection_context.unsubscribe(op_id) - - params['root_value'] = op_id - execution_result = self.execute(params) - try: - if isawaitable(execution_result): - execution_result = await execution_result - if not hasattr(execution_result, '__aiter__'): - await self.send_execution_result( - connection_context, op_id, execution_result) - else: - iterator = await execution_result.__aiter__() - connection_context.register_operation(op_id, iterator) - async for single_result in iterator: - if not connection_context.has_operation(op_id): - break - await self.send_execution_result( - connection_context, op_id, single_result) - except Exception as e: - await self.send_error(connection_context, op_id, e) - await self.send_message(connection_context, op_id, GQL_COMPLETE) - await connection_context.unsubscribe(op_id) - await self.on_operation_complete(connection_context, op_id) - - async def on_operation_complete(self, connection_context, op_id): - # remove the subscription from the sub_statuses dict - with suppress(KeyError): - connection_context.request_context['sub_statuses'].pop(op_id) - - - async def send_execution_result(self, connection_context, op_id, execution_result): - # Resolve any pending promises - if execution_result.data and 'logs' not in execution_result.data: - await resolve(execution_result.data) - request_context = connection_context.request_context - await request_context['resolvers'].flow_delta_processed(request_context, op_id) - else: - await resolve(execution_result.data) - - # NOTE: skip TornadoSubscriptionServer.send_execution_result because it - # calls "resolve" then invokes BaseSubscriptionServer.send_execution_result - await BaseSubscriptionServer.send_execution_result( - self, - connection_context, - op_id, - execution_result, - ) diff --git a/mypy.ini b/mypy.ini index 33ce48d4..9326a592 100644 --- a/mypy.ini +++ b/mypy.ini @@ -3,7 +3,7 @@ python_version = 3.8 ignore_missing_imports = True files = cylc/uiserver # don't run mypy on these files directly -exclude = cylc/uiserver/(tests/|websockets/tornado.py|jupyter_config.py) +exclude = cylc/uiserver/(tests/|jupyter_config.py) # Enable PEP 420 style namespace packages, which we use. # Needed for associating "import foo.bar" with foo/bar.py diff --git a/setup.cfg b/setup.cfg index 9006b989..77370054 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,13 +47,8 @@ install_requires = # dependencies first. This way, if other dependencies (e.g. jupyterhub) # don't pin versions, we will get whatever cylc-flow needs, and not the # bleeding-edge version. - # NB: no graphene version specified; we only make light use of it in our - # own code, so graphene-tornado's transitive version should do. cylc-flow==8.5.* ansimarkup>=1.0.0 - graphene - graphene-tornado==2.6.* - graphql-ws==0.4.4 importlib-resources>=1.3.0; python_version < "3.9" jupyter_server>=2.7 requests @@ -64,11 +59,7 @@ install_requires = # Transitive dependencies that we directly (lightly) use: pyzmq - graphql-core - - # Fix lack of upper pin for rx in graphql-core<2.3.0 (remove when - # upgrading graphene to 3.0): - rx<2 + graphene [options.packages.find] include = cylc* diff --git a/tox.ini b/tox.ini index 91d4b7fb..9542f66d 100644 --- a/tox.ini +++ b/tox.ini @@ -8,7 +8,6 @@ ignore= W503, ; line break after binary operator W504 - exclude= build, dist, @@ -18,6 +17,5 @@ exclude= .tox, .eggs, cylc/uiserver/jupyter_config.py, - cylc/uiserver/websockets/tornado.py paths = ./cylc/uiserver