diff --git a/.coveragerc b/.coveragerc
index cb452fd6..9ddccdc6 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -28,7 +28,6 @@ disable_warnings=
omit =
venv/*
cylc/uiserver/tests/*
- cylc/uiserver/websockets/*
cylc/uiserver/jupyter*_config.py
parallel = True
plugins=
@@ -61,7 +60,6 @@ ignore_errors = False
omit =
venv/*
cylc/uiserver/tests/*
- cylc/uiserver/websockets/*
cylc/uiserver/jupyter*_config.py
precision=2
show_missing=False
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9b6b0fd1..e4b243e9 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -27,7 +27,7 @@ jobs:
fail-fast: false
matrix:
os: ['ubuntu-latest']
- python-version: ['3.8', '3.9']
+ python-version: ['3', '3.8', '3.9']
include:
- os: 'macos-latest'
python-version: '3.8' # oldest supported
diff --git a/changes.d/672.feat.md b/changes.d/672.feat.md
new file mode 100644
index 00000000..f8bdaf82
--- /dev/null
+++ b/changes.d/672.feat.md
@@ -0,0 +1 @@
+Major version upgrade for graphene/graphql-core dependencies. Removed the graphene-tornado and graphql-ws dependencies which had blocked Python 3.10 adoption.
diff --git a/cylc/uiserver/app.py b/cylc/uiserver/app.py
index 97413fa1..13761776 100644
--- a/cylc/uiserver/app.py
+++ b/cylc/uiserver/app.py
@@ -69,11 +69,6 @@
Union,
)
-from cylc.flow.network.graphql import (
- CylcGraphQLBackend,
- IgnoreFieldMiddleware,
-)
-from cylc.flow.profiler import Profiler
from jupyter_server.extension.application import ExtensionApp
from packaging.version import Version
from tornado import ioloop
@@ -92,6 +87,10 @@
)
from traitlets.config.loader import LazyConfigValue
+from cylc.flow.network.graphql import (
+ CylcExecutionContext, IgnoreFieldMiddleware
+)
+from cylc.flow.profiler import Profiler
from cylc.uiserver import __file__ as uis_pkg
from cylc.uiserver.authorise import (
Authorization,
@@ -112,7 +111,7 @@
)
from cylc.uiserver.resolvers import Resolvers
from cylc.uiserver.schema import schema
-from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer
+from cylc.uiserver.graphql.tornado_ws import TornadoSubscriptionServer
from cylc.uiserver.workflows_mgr import WorkflowsManager
@@ -513,11 +512,11 @@ def initialize_handlers(self):
{
'schema': schema,
'resolvers': self.resolvers,
- 'backend': CylcGraphQLBackend(),
'middleware': [
AuthorizationMiddleware,
IgnoreFieldMiddleware
],
+ 'execution_context_class': CylcExecutionContext,
'auth': self.authobj,
}
),
@@ -527,11 +526,11 @@ def initialize_handlers(self):
{
'schema': schema,
'resolvers': self.resolvers,
- 'backend': CylcGraphQLBackend(),
'middleware': [
AuthorizationMiddleware,
IgnoreFieldMiddleware
],
+ 'execution_context_class': CylcExecutionContext,
'batch': True,
'auth': self.authobj,
}
@@ -571,11 +570,11 @@ def initialize_handlers(self):
def set_sub_server(self):
self.subscription_server = TornadoSubscriptionServer(
schema,
- backend=CylcGraphQLBackend(),
middleware=[
IgnoreFieldMiddleware,
AuthorizationMiddleware,
],
+ execution_context_class=CylcExecutionContext,
auth=self.authobj,
)
diff --git a/cylc/uiserver/authorise.py b/cylc/uiserver/authorise.py
index 9333c845..c047b967 100644
--- a/cylc/uiserver/authorise.py
+++ b/cylc/uiserver/authorise.py
@@ -17,11 +17,11 @@
from functools import lru_cache
from getpass import getuser
import grp
-from inspect import iscoroutinefunction
import os
from typing import List, Optional, Union, Set, Tuple
import graphene
+from graphql.pyutils import is_awaitable
from jupyter_server.auth import Authorizer
from tornado import web
@@ -508,10 +508,15 @@ class AuthorizationMiddleware:
def resolve(self, next_, root, info, **args):
current_user = info.context["current_user"]
- # We won't be re-checking auth for return variables
- if len(info.path) > 1:
+ # The resolving starts at the top of the path, so only the first
+ # entry is guarded, and any subsequent fields do not need to be
+ # checked.
+ if len(info.path.as_list()) > 1:
return next_(root, info, **args)
- op_name = self.get_op_name(info.field_name, info.operation.operation)
+ op_name = self.get_op_name(
+ info.field_name,
+ info.operation.operation.value
+ )
# It shouldn't get here but worth checking for zero trust
if not op_name:
self.auth_failed(
@@ -527,12 +532,7 @@ def resolve(self, next_, root, info, **args):
authorised = False
if not authorised:
self.auth_failed(current_user, op_name, http_code=403)
- if (
- info.operation.operation in Authorization.ASYNC_OPS
- or iscoroutinefunction(next_)
- ):
- return self.async_resolve(next_, root, info, **args)
- return next_(root, info, **args)
+ return self.async_resolve(next_, root, info, **args)
def auth_failed(
self,
@@ -588,7 +588,10 @@ def get_op_name(self, field_name: str, operation: str) -> Optional[str]:
async def async_resolve(self, next_, root, info, **args):
"""Return awaited coroutine"""
- return await next_(root, info, **args)
+ result = next_(root, info, **args)
+ if is_awaitable(result):
+ return await result
+ return result
def get_groups(username: str) -> Tuple[List[str], List[str]]:
diff --git a/cylc/uiserver/websockets/__init__.py b/cylc/uiserver/graphql/__init__.py
similarity index 96%
rename from cylc/uiserver/websockets/__init__.py
rename to cylc/uiserver/graphql/__init__.py
index f5633a1b..09dfd0d9 100644
--- a/cylc/uiserver/websockets/__init__.py
+++ b/cylc/uiserver/graphql/__init__.py
@@ -13,7 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-"""Websockets and subscriptions related code."""
+"""GraphQL, Websockets and subscriptions related code."""
from typing import (
Awaitable,
diff --git a/cylc/uiserver/graphql/tornado.py b/cylc/uiserver/graphql/tornado.py
new file mode 100644
index 00000000..40e12c17
--- /dev/null
+++ b/cylc/uiserver/graphql/tornado.py
@@ -0,0 +1,492 @@
+# The MIT License (MIT)
+#
+# Copyright (c) 2016-Present Syrus Akbary
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# ----------------------------------------------------------------------------
+#
+# This is a Cylc port of the graphene-django graphene integration:
+# https://github.com/graphql-python/graphene-django
+# with reference to:
+# https://github.com/graphql-python/graphene-tornado
+# https://github.com/graphql-python/graphql-server
+#
+# Excludes GraphiQL
+
+from asyncio import iscoroutinefunction
+import json
+import re
+import sys
+import traceback
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+from tornado import web
+from tornado.escape import json_encode
+from tornado.escape import to_unicode
+from tornado.httpclient import HTTPClientError
+from tornado.log import app_log
+from tornado.web import HTTPError
+
+
+from graphql import (
+ DocumentNode,
+ ExecutionResult,
+ OperationType,
+ execute,
+ get_operation_ast,
+ parse,
+ validate_schema,
+)
+from graphql.error import GraphQLError
+from graphql.execution.middleware import MiddlewareManager
+from graphql.pyutils import is_awaitable
+from graphql.validation import validate
+
+from cylc.flow.network.graphql import (
+ NULL_VALUE,
+ instantiate_middleware,
+ strip_null
+)
+
+if TYPE_CHECKING:
+ from graphene import Schema
+ from tornado.httputil import HTTPServerRequest
+
+MUTATION_ERRORS_FLAG = "graphene_mutation_has_errors"
+MAX_VALIDATION_ERRORS = None
+
+
+def data_search_action(data, action):
+ if isinstance(data, dict):
+ return {
+ key: data_search_action(val, action)
+ for key, val in data.items()
+ }
+ if isinstance(data, list):
+ return [
+ data_search_action(val, action)
+ for val in data
+ ]
+ return action(data)
+
+
+def get_content_type(request: 'HTTPServerRequest') -> str:
+ return request.headers.get("Content-Type", "").split(";", 1)[0].lower()
+
+
+def get_accepted_content_types(request: 'HTTPServerRequest') -> list:
+ def qualify(x):
+ parts = x.split(";", 1)
+ if len(parts) == 2:
+ match = re.match(
+ r"(^|;)q=(0(\.\d{,3})?|1(\.0{,3})?)(;|$)", parts[1])
+ if match:
+ return parts[0].strip(), float(match.group(2))
+ return parts[0].strip(), 1
+
+ raw_content_types = request.headers.get("Accept", "*/*").split(",")
+ qualified_content_types = map(qualify, raw_content_types)
+ return [
+ x[0]
+ for x in sorted(
+ qualified_content_types, key=lambda x: x[1], reverse=True)
+ ]
+
+
+class ExecutionError(Exception):
+ def __init__(self, status_code=400, errors=None):
+ self.status_code = status_code
+ if errors is None:
+ self.errors = []
+ else:
+ self.errors = [str(e) for e in errors]
+ self.message = "\n".join(self.errors)
+
+
+class TornadoGraphQLHandler(web.RequestHandler):
+
+ document: Optional[DocumentNode]
+ graphql_params: Optional[Tuple[Any, Any, Any, Any]]
+ middleware: Optional[Union[MiddlewareManager, List[Callable], None]]
+ parsed_body: Optional[Dict[str, Any]]
+
+ def initialize(
+ self,
+ schema: 'Schema',
+ middleware: Union[MiddlewareManager, List[type], None] = None,
+ root_value=None,
+ pretty: bool = False,
+ batch: bool = False,
+ subscription_path=None,
+ execution_context_class=None,
+ validation_rules=None,
+ ) -> None:
+ super(TornadoGraphQLHandler, self).initialize()
+ self.schema = schema
+ self.root_value = root_value
+ self.pretty = pretty
+ self.batch = batch
+ self.subscription_path = subscription_path
+ self.execution_context_class = execution_context_class
+ self.validation_rules = validation_rules
+
+ self.graphql_params = None
+ self.parsed_body = None
+
+ if isinstance(middleware, MiddlewareManager):
+ self.middleware = middleware
+ elif middleware is not None:
+ self.middleware = list(instantiate_middleware(middleware))
+ else:
+ self.middleware = None
+
+ def get_context(self):
+ return self.request
+
+ def get_root_value(self):
+ return self.root_value
+
+ def get_middleware(self) -> Union[MiddlewareManager, List[Callable], None]:
+ return self.middleware
+
+ def get_parsed_body(self):
+ return self.parsed_body
+
+ async def get(self) -> None:
+ try:
+ await self.run("get")
+ except Exception as ex:
+ self.handle_error(ex)
+
+ async def post(self) -> None:
+ try:
+ await self.run("post")
+ except Exception as ex:
+ self.handle_error(ex)
+
+ async def run(self, *args, **kwargs):
+ try:
+ data = self.parse_body()
+
+ if self.batch:
+ responses = [
+ await self.get_response(entry)
+ for entry in data
+ ]
+ result = "[{}]".format(
+ ",".join([response[0] for response in responses])
+ )
+ status_code = (
+ responses
+ and max(responses, key=lambda response: response[1])[1]
+ or 200
+ )
+ else:
+ result, status_code = await self.get_response(data)
+
+ self.set_status(status_code)
+ self.set_header("Content-Type", "application/json")
+ self.write(result)
+ await self.finish()
+
+ except HTTPClientError as e:
+ response = e.response
+ response["Content-Type"] = "application/json"
+ response.content = self.json_encode(
+ self.request, {"errors": [self.format_error(e)]}
+ )
+ return response
+
+ async def get_response(self, data):
+ query, variables, operation_name, _id = self.get_graphql_params(
+ self.request, data
+ )
+
+ execution_result = await self.execute_graphql_request(
+ data, query, variables, operation_name
+ )
+
+ status_code = 200
+ if execution_result:
+ response = {}
+
+ if is_awaitable(execution_result) or iscoroutinefunction(
+ execution_result
+ ):
+ execution_result = await execution_result
+
+ if hasattr(execution_result, "get"):
+ execution_result = execution_result.get()
+
+ if execution_result.errors:
+ response["errors"] = [
+ self.format_error(e) for e in execution_result.errors
+ ]
+
+ if execution_result.errors and any(
+ not getattr(e, "path", None) for e in execution_result.errors
+ ):
+ status_code = 400
+ else:
+ response["data"] = execution_result.data
+
+ if self.batch:
+ response["id"] = _id
+ response["status"] = status_code
+ try:
+ result = self.json_encode(response)
+ except TypeError:
+ # Catch exceptions in response
+ errors = []
+
+ def exc_to_errors(data):
+ if isinstance(data, Exception):
+ errors.append({
+ 'message': (
+ f'{data.value}'
+ if hasattr(data, 'value') else f'{data}'
+ )
+ })
+ return NULL_VALUE
+ return data
+
+ response = data_search_action(
+ response,
+ exc_to_errors
+ )
+ response.setdefault("errors", []).extend(errors)
+ response = strip_null(response)
+
+ result = self.json_encode(response)
+ else:
+ result = None
+
+ return result, status_code
+
+ def json_encode(self, d, pretty=False):
+ if (self.pretty or pretty) or self.get_query_argument("pretty", False):
+ return json.dumps(
+ d, sort_keys=True, indent=2, separators=(",", ": "))
+
+ return json.dumps(d, separators=(",", ":"))
+
+ def parse_body(self):
+ content_type = get_content_type(self.request)
+
+ if content_type == "application/graphql":
+ self.parsed_body = {"query": to_unicode(self.request.body)}
+ return self.parsed_body
+
+ elif content_type == "application/json":
+ try:
+ body = self.request.body
+ except Exception as e:
+ raise ExecutionError(400, e)
+
+ try:
+ request_json = json.loads(body)
+ if self.batch:
+ if not isinstance(request_json, list):
+ raise AssertionError(
+ "Batch requests should receive a list"
+ ", but received {}."
+ ).format(repr(request_json))
+ if len(request_json <= 0):
+ raise AssertionError(
+ "Received an empty list in the batch request."
+ )
+ else:
+ if not isinstance(request_json, dict):
+ raise AssertionError(
+ "The received data is not a valid JSON query."
+ )
+ self.parsed_body = request_json
+ return self.parsed_body
+ except AssertionError as e:
+ raise HTTPError(status_code=400, log_message=str(e))
+ except (TypeError, ValueError):
+ raise HTTPError(
+ status_code=400, log_message="POST body sent invalid JSON."
+ )
+
+ elif content_type in [
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
+ ]:
+ self.parsed_body = self.request.query_arguments
+ return self.parsed_body
+
+ self.parsed_body = {}
+ return self.parsed_body
+
+ async def execute_graphql_request(
+ self, data, query, variables, operation_name
+ ):
+ if not query:
+ raise HTTPError(
+ status_code=400, log_message="Must provide query string."
+ )
+
+ schema = self.schema.graphql_schema
+
+ schema_validation_errors = validate_schema(schema)
+ if schema_validation_errors:
+ return ExecutionResult(data=None, errors=schema_validation_errors)
+
+ try:
+ self.document = parse(query)
+ except Exception as e:
+ return ExecutionResult(errors=[e])
+
+ operation_ast = get_operation_ast(self.document, operation_name)
+
+ if (
+ self.request.method.lower() == "get"
+ and operation_ast is not None
+ and operation_ast.operation != OperationType.QUERY
+ ):
+ raise HTTPError(
+ status_code=405,
+ log_message=(
+ f'Can only perform a {operation_ast.operation.value} '
+ 'operation from a POST request.'
+ ),
+ )
+
+ validation_errors = validate(
+ schema,
+ self.document,
+ self.validation_rules,
+ MAX_VALIDATION_ERRORS,
+ )
+ if validation_errors:
+ return ExecutionResult(data=None, errors=validation_errors)
+
+ try:
+ execute_options = {
+ "root_value": self.get_root_value(),
+ "context_value": self.get_context(),
+ "variable_values": variables,
+ "operation_name": operation_name,
+ "middleware": self.get_middleware(),
+ }
+ if self.execution_context_class:
+ execute_options[
+ "execution_context_class"
+ ] = self.execution_context_class
+
+ result = await self.execute(
+ schema,
+ self.document,
+ **execute_options
+ )
+
+ return result
+ except Exception as e:
+ return ExecutionResult(errors=[e])
+
+ async def execute(self, *args, **kwargs):
+ return execute(*args, **kwargs)
+
+ def request_wants_html(self):
+ accepted = get_accepted_content_types(self.request)
+ accepted_length = len(accepted)
+ # the list will be ordered in preferred first - so we have to make
+ # sure the most preferred gets the highest number
+ html_priority = (
+ accepted_length - accepted.index("text/html")
+ if "text/html" in accepted
+ else 0
+ )
+ json_priority = (
+ accepted_length - accepted.index("application/json")
+ if "application/json" in accepted
+ else 0
+ )
+
+ return html_priority > json_priority
+
+ def get_graphql_params(self, request, data):
+ if self.graphql_params:
+ return self.graphql_params
+
+ single_args = {}
+ for key in request.arguments.keys():
+ single_args[key] = self.decode_argument(
+ request.arguments.get(key)[0])
+
+ query = single_args.get("query") or data.get("query")
+ variables = single_args.get("variables") or data.get("variables")
+ _id = single_args.get("id") or data.get("id")
+
+ if variables and isinstance(variables, str):
+ try:
+ variables = json.loads(variables)
+ except Exception:
+ raise HTTPError(
+ status_code=400,
+ log_message="Variables are invalid JSON."
+ )
+
+ operation_name = (
+ single_args.get("operationName") or data.get("operationName")
+ )
+ if operation_name == "null":
+ operation_name = None
+
+ self.graphql_params = query, variables, operation_name, _id
+ return self.graphql_params
+
+ def handle_error(self, ex: Exception) -> None:
+ if not isinstance(ex, (web.HTTPError, ExecutionError, GraphQLError)):
+ tb = "".join(traceback.format_exception(*sys.exc_info()))
+ app_log.error("Error: {0} {1}".format(ex, tb))
+ self.set_status(self.error_status(ex))
+ error_json = json_encode({"errors": self.error_format(ex)})
+ app_log.debug("error_json: %s", error_json)
+ self.write(error_json)
+
+ @staticmethod
+ def error_status(exception: Exception) -> int:
+ if isinstance(exception, web.HTTPError):
+ return exception.status_code
+ elif isinstance(exception, (ExecutionError, GraphQLError)):
+ return 400
+ else:
+ return 500
+
+ @staticmethod
+ def error_format(exception: Exception) -> List[Dict[str, Any]]:
+ if isinstance(exception, ExecutionError):
+ return [{"message": e} for e in exception.errors]
+ elif isinstance(exception, GraphQLError):
+ return [{"message": exception.formatted["message"]}]
+ elif isinstance(exception, web.HTTPError):
+ return [{"message": exception.log_message}]
+ else:
+ return [{"message": "Unknown server error"}]
+
+ @staticmethod
+ def format_error(error):
+ if isinstance(error, GraphQLError):
+ return error.formatted
+
+ return {"message": str(error)}
diff --git a/cylc/uiserver/graphql/tornado_ws.py b/cylc/uiserver/graphql/tornado_ws.py
new file mode 100644
index 00000000..45b2514c
--- /dev/null
+++ b/cylc/uiserver/graphql/tornado_ws.py
@@ -0,0 +1,398 @@
+# The MIT License (MIT)
+#
+# Copyright (c) 2016-Present Syrus Akbary
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# ----------------------------------------------------------------------------
+#
+# This code was derived from coded in the graphql-ws project, and code that
+# was offered to the graphql-ws project by members of the Cylc
+# development team but not merged.
+#
+# * https://github.com/graphql-python/graphql-ws
+# * https://github.com/graphql-python/graphql-ws/pull/25/files
+#
+# It has been evolved to suit and ported to graphql-core v3.
+
+import asyncio
+from asyncio.queues import QueueEmpty
+from contextlib import suppress
+from inspect import isawaitable
+import json
+from weakref import WeakSet
+
+from tornado.websocket import WebSocketClosedError
+from graphql import (
+ parse,
+ validate,
+ ExecutionResult,
+ GraphQLError,
+ MiddlewareManager,
+)
+from graphql.pyutils import is_awaitable
+
+from cylc.flow.network.graphql import instantiate_middleware
+from cylc.flow.network.graphql_subscribe import subscribe
+
+from cylc.uiserver.authorise import AuthorizationMiddleware
+from cylc.uiserver.schema import SUB_RESOLVER_MAPPING
+
+
+NO_MSG_DELAY = 1.0
+
+GRAPHQL_WS = "graphql-ws"
+WS_PROTOCOL = GRAPHQL_WS
+GQL_CONNECTION_INIT = "connection_init" # Client -> Server
+GQL_CONNECTION_ACK = "connection_ack" # Server -> Client
+GQL_CONNECTION_ERROR = "connection_error" # Server -> Client
+GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server
+GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client
+GQL_START = "start" # Client -> Server
+GQL_DATA = "data" # Server -> Client
+GQL_ERROR = "error" # Server -> Client
+GQL_COMPLETE = "complete" # Server -> Client
+GQL_STOP = "stop" # Client -> Server
+
+
+class ConnectionClosedException(Exception):
+ pass
+
+
+class TornadoConnectionContext:
+
+ def __init__(self, ws, request_context=None):
+ self.ws = ws
+ self.operations = {}
+ self.request_context = request_context
+ self.pending_tasks = WeakSet()
+
+ def has_operation(self, op_id):
+ return op_id in self.operations
+
+ def register_operation(self, op_id, async_iterator):
+ self.operations[op_id] = async_iterator
+
+ def get_operation(self, op_id):
+ return self.operations[op_id]
+
+ def remove_operation(self, op_id):
+ try:
+ return self.operations.pop(op_id)
+ except KeyError:
+ return
+
+ async def receive(self):
+ try:
+ return self.ws.recv_nowait()
+ except WebSocketClosedError:
+ raise ConnectionClosedException()
+
+ async def send(self, data):
+ if self.closed:
+ return
+ await self.ws.write_message(data)
+
+ @property
+ def closed(self):
+ return self.ws.close_code is not None
+
+ async def close(self, code):
+ await self.ws.close(code)
+
+ def remember_task(self, task):
+ self.pending_tasks.add(task)
+ # Clear completed tasks
+ self.pending_tasks -= WeakSet(
+ task for task in self.pending_tasks if task.done()
+ )
+
+ async def unsubscribe(self, op_id):
+ async_iterator = self._unsubscribe(op_id)
+ if (
+ getattr(async_iterator, "future", None)
+ and async_iterator.future.cancel()
+ ):
+ await async_iterator.future
+
+ def _unsubscribe(self, op_id):
+ async_iterator = self.remove_operation(op_id)
+ if hasattr(async_iterator, "dispose"):
+ async_iterator.dispose()
+ return async_iterator
+
+ async def unsubscribe_all(self):
+ awaitables = [
+ self.unsubscribe(op_id)
+ for op_id in list(self.operations)
+ ]
+ for task in self.pending_tasks:
+ task.cancel()
+ awaitables.append(task)
+ if awaitables:
+ with suppress(asyncio.CancelledError):
+ await asyncio.gather(*awaitables)
+
+
+class TornadoSubscriptionServer:
+
+ def __init__(
+ self,
+ schema,
+ keep_alive=True,
+ loop=None,
+ middleware=None,
+ execution_context_class=None,
+ auth=None
+ ):
+ self.schema = schema
+ self.loop = loop
+ self.middleware = middleware
+ self.execution_context_class = execution_context_class
+ self.auth = auth
+
+ async def execute(self, params):
+ # Parse query to document
+ try:
+ document = parse(params['query'])
+ except GraphQLError as error:
+ return ExecutionResult(data=None, errors=[error])
+
+ # Validate document against schema
+ validation_errors = validate(self.schema.graphql_schema, document)
+ if validation_errors:
+ return ExecutionResult(data=None, errors=validation_errors)
+
+ # execute subscription
+ return await subscribe(
+ self.schema.graphql_schema,
+ document,
+ **params['kwargs']
+ )
+
+ def process_message(self, connection_context, parsed_message):
+ task = asyncio.ensure_future(
+ self._process_message(connection_context, parsed_message),
+ loop=self.loop
+ )
+ connection_context.remember_task(task)
+ return task
+
+ async def _process_message(self, connection_context, parsed_message):
+ op_id = parsed_message.get("id")
+ op_type = parsed_message.get("type")
+ payload = parsed_message.get("payload")
+
+ if op_type == GQL_CONNECTION_INIT:
+ return await self.on_connection_init(
+ connection_context, op_id, payload
+ )
+
+ elif op_type == GQL_CONNECTION_TERMINATE:
+ return self.on_connection_terminate(connection_context, op_id)
+
+ elif op_type == GQL_START:
+ if not isinstance(payload, dict):
+ raise AssertionError("The payload must be a dict")
+ params = self.get_graphql_params(connection_context, payload)
+ return await self.on_start(connection_context, op_id, params)
+
+ elif op_type == GQL_STOP:
+ return await self.on_stop(connection_context, op_id)
+
+ else:
+ return await self.send_error(
+ connection_context,
+ op_id,
+ Exception("Invalid message type: {}.".format(op_type)),
+ )
+
+ async def on_connection_init(self, connection_context, op_id, payload):
+ try:
+ await self.on_connect(connection_context, payload)
+ await self.send_message(
+ connection_context, op_type=GQL_CONNECTION_ACK)
+ except Exception as e:
+ await self.send_error(
+ connection_context, op_id, e, GQL_CONNECTION_ERROR)
+ await connection_context.close(1011)
+
+ async def on_connect(self, connection_context, payload):
+ pass
+
+ def on_connection_terminate(self, connection_context, op_id):
+ return connection_context.close(1011)
+
+ def get_graphql_params(self, connection_context, payload):
+ # Create a new context object for each subscription,
+ # which allows it to carry a unique subscription id.
+ params = {
+ "variable_values": payload.get("variables"),
+ "operation_name": payload.get("operationName"),
+ "context_value": dict(
+ payload.get("context", connection_context.request_context)
+ ),
+ "subscribe_resolver_map": SUB_RESOLVER_MAPPING,
+ }
+ # If middleware get instantiated here (optional), they will
+ # be local/private to each subscription.
+ if self.middleware is not None:
+ middleware = list(
+ instantiate_middleware(self.middleware)
+ )
+ else:
+ middleware = self.middleware
+ for mw in self.middleware:
+ if mw == AuthorizationMiddleware:
+ mw.auth = self.auth
+ return {
+ 'query': payload.get("query"),
+ 'kwargs': dict(
+ params,
+ middleware=MiddlewareManager(
+ *middleware,
+ ),
+ execution_context_class=self.execution_context_class,
+ )
+ }
+
+ async def on_open(self, connection_context):
+ pass
+
+ async def on_stop(self, connection_context, op_id):
+ return await connection_context.unsubscribe(op_id)
+
+ async def on_close(self, connection_context):
+ return await connection_context.unsubscribe_all()
+
+ async def handle(self, ws, request_context=None):
+ await asyncio.shield(self._handle(ws, request_context))
+
+ async def _handle(self, ws, request_context=None):
+ connection_context = TornadoConnectionContext(ws, request_context)
+ await self.on_open(connection_context)
+ while True:
+ message = None
+ try:
+ if connection_context.closed:
+ raise ConnectionClosedException()
+ message = await connection_context.receive()
+ except QueueEmpty:
+ pass
+ except ConnectionClosedException:
+ break
+ if message:
+ await self.on_message(connection_context, message)
+ else:
+ await asyncio.sleep(NO_MSG_DELAY)
+
+ await self.on_close(connection_context)
+
+ async def on_start(self, connection_context, op_id, params):
+ # Attempt to unsubscribe first in case we already have a subscription
+ # with this id.
+ await connection_context.unsubscribe(op_id)
+
+ params['kwargs']['root_value'] = op_id
+ execution_result = await self.execute(params)
+ iterator = None
+ try:
+ if isawaitable(execution_result):
+ execution_result = await execution_result
+ if not hasattr(execution_result, '__aiter__'):
+ await self.send_execution_result(
+ connection_context, op_id, execution_result)
+ else:
+ iterator = execution_result.__aiter__()
+ connection_context.register_operation(op_id, iterator)
+ async for single_result in iterator:
+ if not connection_context.has_operation(op_id):
+ break
+ await self.send_execution_result(
+ connection_context, op_id, single_result)
+ except (GeneratorExit, asyncio.CancelledError):
+ raise
+ except Exception as e:
+ await self.send_error(connection_context, op_id, e)
+ finally:
+ if iterator:
+ await iterator.aclose()
+ await self.send_message(connection_context, op_id, GQL_COMPLETE)
+ await connection_context.unsubscribe(op_id)
+ await self.on_operation_complete(connection_context, op_id)
+
+ async def send_message(
+ self, connection_context, op_id=None, op_type=None, payload=None
+ ):
+ message = self.build_message(op_id, op_type, payload)
+ return await connection_context.send(message)
+
+ def build_message(self, _id, op_type, payload):
+ message = {}
+ if _id is not None:
+ message["id"] = _id
+ if op_type is not None:
+ message["type"] = op_type
+ if payload is not None:
+ message["payload"] = payload
+ if not message:
+ raise AssertionError("You need to send at least one thing")
+ return message
+
+ async def send_execution_result(
+ self, connection_context, op_id, execution_result):
+ # Resolve any pending promises
+ if is_awaitable(execution_result.data):
+ await execution_result.data
+ if execution_result.data and 'logs' not in execution_result.data:
+ request_context = connection_context.request_context
+ await request_context['resolvers'].flow_delta_processed(
+ request_context, op_id)
+
+ result = execution_result.formatted
+ return await self.send_message(
+ connection_context, op_id, GQL_DATA, result
+ )
+
+ async def on_operation_complete(self, connection_context, op_id):
+ # remove the subscription from the sub_statuses dict
+ with suppress(KeyError):
+ connection_context.request_context['sub_statuses'].pop(op_id)
+
+ async def send_error(
+ self, connection_context, op_id, error, error_type=None
+ ):
+ if error_type is None:
+ error_type = GQL_ERROR
+
+ if error_type not in {GQL_CONNECTION_ERROR, GQL_ERROR}:
+ raise AssertionError(
+ "error_type should be one of the allowed error messages"
+ " GQL_CONNECTION_ERROR or GQL_ERROR"
+ )
+
+ error_payload = {"message": str(error)}
+
+ return await self.send_message(
+ connection_context, op_id, error_type, error_payload)
+
+ async def on_message(self, connection_context, message):
+ try:
+ if not isinstance(message, dict):
+ parsed_message = json.loads(message)
+ if not isinstance(parsed_message, dict):
+ raise AssertionError("Payload must be an object.")
+ else:
+ parsed_message = message
+ except Exception as e:
+ return await self.send_error(connection_context, None, e)
+
+ return self.process_message(connection_context, parsed_message)
diff --git a/cylc/uiserver/handlers.py b/cylc/uiserver/handlers.py
index 5809f48b..4238a198 100644
--- a/cylc/uiserver/handlers.py
+++ b/cylc/uiserver/handlers.py
@@ -19,16 +19,9 @@
import json
import os
import re
-from typing import (
- TYPE_CHECKING,
- Callable,
- Dict,
-)
+from typing import TYPE_CHECKING, Callable, Dict, Awaitable, Optional
from cylc.flow import __version__ as cylc_flow_version
-from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler
-from graphql import get_default_backend
-from graphql_ws.constants import GRAPHQL_WS
from jupyter_server.base.handlers import JupyterHandler
from tornado import (
web,
@@ -37,20 +30,17 @@
from tornado.ioloop import IOLoop
from cylc.uiserver import __version__
-from cylc.uiserver.authorise import (
- Authorization,
- AuthorizationMiddleware,
-)
+from cylc.uiserver.authorise import Authorization, AuthorizationMiddleware
+from cylc.uiserver.graphql import authenticated as websockets_authenticated
+from cylc.uiserver.graphql.tornado import TornadoGraphQLHandler
+from cylc.uiserver.graphql.tornado_ws import GRAPHQL_WS
from cylc.uiserver.utils import is_bearer_token_authenticated
-from cylc.uiserver.websockets import authenticated as websockets_authenticated
if TYPE_CHECKING:
- from graphql.execution import ExecutionResult
- from jupyter_server.auth.identity import User as JPSUser
-
from cylc.uiserver.resolvers import Resolvers
- from cylc.uiserver.websockets.tornado import TornadoSubscriptionServer
+ from cylc.uiserver.graphql.tornado_ws import TornadoSubscriptionServer
+ from jupyter_server.auth.identity import User as JPSUser
ME = getpass.getuser()
@@ -164,7 +154,7 @@ class CylcAppHandler(JupyterHandler):
def initialize(self, auth):
self.auth = auth
- super().initialize()
+ JupyterHandler.initialize(self)
@property
def hub_users(self):
@@ -322,32 +312,41 @@ class UIServerGraphQLHandler(CylcAppHandler, TornadoGraphQLHandler):
def set_default_headers(self) -> None:
self.set_header('Server', '')
- def initialize(self, schema=None, executor=None, middleware=None,
- root_value=None, graphiql=False, pretty=False,
- batch=False, backend=None, auth=None, **kwargs):
- super(TornadoGraphQLHandler, self).initialize()
- self.auth = auth
- self.schema = schema
+ def initialize(
+ self,
+ schema,
+ middleware=None,
+ root_value=None,
+ pretty=False,
+ batch=False,
+ execution_context_class=None,
+ validation_rules=None,
+ auth=None,
+ **kwargs,
+ ):
+ TornadoGraphQLHandler.initialize(
+ self,
+ schema,
+ middleware=middleware,
+ root_value=root_value,
+ pretty=pretty,
+ batch=batch,
+ execution_context_class=execution_context_class,
+ validation_rules=validation_rules,
+ )
+ CylcAppHandler.initialize(self, auth)
- if middleware is not None:
- self.middleware = list(self.instantiate_middleware(middleware))
# Make authorization info available to auth middleware
for mw in self.middleware:
if isinstance(mw, AuthorizationMiddleware):
mw.auth = self.auth
- self.executor = executor
- self.root_value = root_value
- self.pretty = pretty
- self.graphiql = graphiql
- self.batch = batch
- self.backend = backend or get_default_backend()
+
# Set extra attributes
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
- @property
- def context(self):
+ def get_context(self):
"""The GraphQL context passed to resolvers (incl middleware)."""
return {
'graphql_params': self.graphql_params,
@@ -357,15 +356,10 @@ def context(self):
}
@web.authenticated # type: ignore[arg-type]
- async def execute(self, *args, **kwargs) -> 'ExecutionResult':
- # Use own backend, and TornadoGraphQLHandler already does validation.
- return await self.schema.execute(
- *args,
- backend=self.backend,
- variable_values=kwargs.get('variables'),
- validate=False,
- **kwargs,
- )
+ async def execute(
+ self, *args, **kwargs
+ ) -> Optional[Awaitable[None]]:
+ return await TornadoGraphQLHandler.execute(self, *args, **kwargs)
@web.authenticated
async def run(self, *args, **kwargs):
diff --git a/cylc/uiserver/resolvers.py b/cylc/uiserver/resolvers.py
index 056d4eaf..f322035b 100644
--- a/cylc/uiserver/resolvers.py
+++ b/cylc/uiserver/resolvers.py
@@ -16,6 +16,7 @@
"""GraphQL resolvers for use in data accessing and mutation of workflows."""
import asyncio
+from enum import Enum
from contextlib import suppress
from copy import deepcopy
import errno
@@ -42,17 +43,14 @@
Union,
)
+from graphql.language import print_ast
+import psutil
+
from cylc.flow.data_store_mgr import WORKFLOW
from cylc.flow.exceptions import CylcError
from cylc.flow.id import Tokens
from cylc.flow.network.resolvers import BaseResolvers
-from cylc.flow.scripts.clean import (
- CleanOptions,
- run,
-)
-from graphql.language.base import print_ast
-import psutil
-
+from cylc.flow.scripts.clean import CleanOptions, run
if TYPE_CHECKING:
from concurrent.futures import Executor
@@ -61,7 +59,7 @@
from cylc.flow.data_store_mgr import DataStoreMgr
from cylc.flow.option_parsers import Options
- from graphql import ResolveInfo
+ from graphql import GraphQLResolveInfo
from cylc.uiserver.app import CylcUIServer
from cylc.uiserver.workflows_mgr import WorkflowsManager
@@ -142,6 +140,8 @@ def _build_cmd(cmd: List, args: Dict) -> List:
if isinstance(value, int) and not isinstance(value, bool):
# Any integer items need converting to strings:
value = str(value)
+ elif isinstance(value, Enum):
+ value = value.value
value = [value]
for item in value:
cmd.append(key)
@@ -537,7 +537,7 @@ def __init__(
# Mutations
async def mutator(
self,
- info: 'ResolveInfo',
+ info: 'GraphQLResolveInfo',
command: str,
w_args: Dict[str, Any],
_kwargs: Dict[str, Any],
@@ -562,7 +562,8 @@ async def mutator(
# Create a modified request string,
# containing only the current mutation/field.
operation_ast = deepcopy(info.operation)
- operation_ast.selection_set.selections = info.field_asts
+ operation_ast.selection_set.selections = tuple(
+ n for n in info.field_nodes)
graphql_args = {
'request_string': print_ast(operation_ast),
@@ -574,12 +575,19 @@ async def mutator(
async def service(
self,
- info: 'ResolveInfo',
+ info: 'GraphQLResolveInfo',
command: str,
workflows: Iterable['Tokens'],
kwargs: Dict[str, Any],
) -> List[Union[bool, str]]:
+ # GraphQL v3 includes all variables that are set, even if set to null.
+ kwargs = {
+ k: v
+ for k, v in kwargs.items()
+ if v is not None
+ }
+
if command == 'clean': # noqa: SIM116
return await Services.clean(
workflows,
@@ -605,7 +613,7 @@ async def service(
async def subscription_service(
self,
- info: 'ResolveInfo',
+ info: 'GraphQLResolveInfo',
_command: str,
ids: List[Tokens],
file=None
@@ -652,7 +660,7 @@ def kill_process_tree(
async def list_log_files(
root: Optional[Any],
- info: 'ResolveInfo',
+ info: 'GraphQLResolveInfo',
id: str, # noqa: required to match schema arg name
):
tokens = Tokens(id)
@@ -665,7 +673,7 @@ async def list_log_files(
async def stream_log(
root: Optional[Any],
- info: 'ResolveInfo',
+ info: 'GraphQLResolveInfo',
*,
command='cat_log',
id: str, # noqa: required to match schema arg name
diff --git a/cylc/uiserver/schema.py b/cylc/uiserver/schema.py
index f1c73378..ef8b3b11 100644
--- a/cylc/uiserver/schema.py
+++ b/cylc/uiserver/schema.py
@@ -30,6 +30,10 @@
Tuple,
)
+import graphene
+from graphene.types.generic import GenericScalar
+from graphene.types.schema import identity_resolve
+
from cylc.flow.data_store_mgr import (
JOBS,
TASKS,
@@ -38,6 +42,7 @@
from cylc.flow.network.schema import (
NODE_MAP,
STRIP_NULL_DEFAULT,
+ SUB_RESOLVER_MAPPING,
CyclePoint,
GenericResponse,
Job,
@@ -47,7 +52,7 @@
Subscriptions,
Task,
WorkflowID,
- WorkflowRunMode as RunMode,
+ WorkflowRunMode,
_mut_field,
get_nodes_all,
process_resolver_info,
@@ -64,8 +69,6 @@
)
from cylc.flow.util import sstrip
from cylc.flow.workflow_files import WorkflowFiles
-import graphene
-from graphene.types.generic import GenericScalar
from cylc.uiserver.resolvers import (
Resolvers,
@@ -75,12 +78,12 @@
if TYPE_CHECKING:
- from graphql import ResolveInfo
+ from graphql import GraphQLResolveInfo
async def mutator(
root: Optional[Any],
- info: 'ResolveInfo',
+ info: 'GraphQLResolveInfo',
*,
command: str,
workflows: Optional[List[str]] = None,
@@ -98,7 +101,11 @@ async def mutator(
resolvers: 'Resolvers' = (
info.context.get('resolvers') # type: ignore[union-attr]
)
- res = await resolvers.service(info, command, parsed_workflows, kwargs)
+ try:
+ res = await resolvers.service(info, command, parsed_workflows, kwargs)
+ except Exception as exc:
+ resolvers.log.exception(exc)
+ raise
return GenericResponse(result=res)
@@ -167,9 +174,7 @@ class Arguments:
Hold all tasks after this cycle point.
''')
)
- mode = RunMode(
- default_value=RunMode.Live.name
- )
+ mode = WorkflowRunMode(default_value=WorkflowRunMode.Live)
host = graphene.String(
description=sstrip('''
Specify the host on which to start-up the workflow. If not
@@ -845,6 +850,16 @@ class LogFiles(graphene.ObjectType):
)
+# TODO: Change to use subscribe arg/default.
+# See https://github.com/cylc/cylc-flow/issues/6688
+# graphql-core has a subscribe field for both Meta and Field,
+# graphene at v3.4.3 does not. As a workaround
+# the subscribe function is looked up via the following mapping:
+SUB_RESOLVER_MAPPING.update({
+ 'logs': stream_log, # type: ignore
+})
+
+
class UISSubscriptions(Subscriptions):
# Example graphiql workflow log subscription:
# subscription {
@@ -872,7 +887,7 @@ class Logs(graphene.ObjectType):
required=False,
description='File name of job log to fetch, e.g. job.out'
),
- resolver=stream_log
+ resolver=identity_resolve
)
diff --git a/cylc/uiserver/tests/test_graphql.py b/cylc/uiserver/tests/test_graphql.py
index d0b26d0b..7ff0c7d9 100644
--- a/cylc/uiserver/tests/test_graphql.py
+++ b/cylc/uiserver/tests/test_graphql.py
@@ -161,3 +161,18 @@ async def _log(*args, **kwargs):
}
}
''').strip()
+
+ # issue clean mutation
+ response = await gql_query(
+ *('cylc', 'graphql'),
+ query='''
+ mutation {
+ clean(workflows: ["%s"]){
+ result
+ }
+ }
+ ''' % (
+ Tokens(user='me', workflow='foo').id,
+ ),
+ )
+ assert response.code == 200
diff --git a/cylc/uiserver/tests/test_handlers.py b/cylc/uiserver/tests/test_handlers.py
index e07d6066..614550d8 100644
--- a/cylc/uiserver/tests/test_handlers.py
+++ b/cylc/uiserver/tests/test_handlers.py
@@ -20,11 +20,11 @@
from unittest.mock import MagicMock
import pytest
-from graphql_ws.constants import GRAPHQL_WS
from tornado.httputil import HTTPServerRequest
from tornado.testing import AsyncHTTPTestCase, get_async_test_timeout
from tornado.web import Application
+from cylc.uiserver.graphql.tornado_ws import GRAPHQL_WS
from cylc.uiserver.handlers import SubscriptionHandler
diff --git a/cylc/uiserver/tests/test_resolvers.py b/cylc/uiserver/tests/test_resolvers.py
index 36a41901..026ec1bd 100644
--- a/cylc/uiserver/tests/test_resolvers.py
+++ b/cylc/uiserver/tests/test_resolvers.py
@@ -16,7 +16,6 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Tuple
-from async_timeout import timeout
import logging
import os
import pytest
@@ -24,6 +23,12 @@
from subprocess import Popen, TimeoutExpired
from types import SimpleNamespace
+import sys
+if sys.version_info >= (3, 11):
+ from asyncio import timeout
+else:
+ from async_timeout import timeout
+
from cylc.flow import CYLC_LOG
from cylc.flow.exceptions import CylcError
from cylc.flow.id import Tokens
diff --git a/cylc/uiserver/tests/test_schema.py b/cylc/uiserver/tests/test_schema.py
index 4c03be36..d1e393c7 100644
--- a/cylc/uiserver/tests/test_schema.py
+++ b/cylc/uiserver/tests/test_schema.py
@@ -14,9 +14,10 @@
# along with this program. If not, see .
import importlib
+import pytest
import cylc.uiserver.schema
-from cylc.uiserver.schema import NODE_MAP as UIS_NODE_MAP
+from cylc.uiserver.schema import NODE_MAP as UIS_NODE_MAP, mutator
def test_node_map():
@@ -30,3 +31,19 @@ def test_node_map():
_class = getattr(cylc.uiserver.schema, type_name)
# It is not straightforward to check that _class is a Graphene class
assert 'graphene' in type(_class).__module__
+
+
+async def test_mutator(cylc_uis):
+ """Test exception and arg variants."""
+ class Info:
+ context = {'resolvers': cylc_uis.resolvers, 'id': 1}
+
+ with pytest.raises(Exception) as exc:
+ await mutator(
+ None,
+ Info(),
+ command='NotACommand',
+ workflows=None,
+ args={'This': 'That'}
+ )
+ assert 'NotImplementedError' in f'{exc}'
diff --git a/cylc/uiserver/websockets/resolve.py b/cylc/uiserver/websockets/resolve.py
deleted file mode 100644
index 71bb8f23..00000000
--- a/cylc/uiserver/websockets/resolve.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# MIT License
-#
-# Copyright (c) 2017, Syrus Akbary
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.
-
-"""
-This file contains an implementation of "resolve" derived from the one
-found in the graphql-ws library with the above license.
-
-This is temporary code until the change makes its way upstream.
-"""
-
-# NOTE: transient dependency from graphql-ws purposefully not
-# reflected in cylc-uiserver dependencies
-from promise import Promise
-
-from graphql_ws.base_async import is_awaitable
-
-
-async def resolve(
- data,
- _container=None,
- _key=None,
-):
- """
- Wait on any awaitable children of a data element and resolve
- any Promises.
- """
- stack = [(data, _container, _key)]
-
- while stack:
- _data, _container, _key = stack.pop()
-
- if is_awaitable(_data):
- _data = await _data
- if isinstance(_data, Promise):
- _data = _data.value
- if _container is not None:
- _container[_key] = _data
- if isinstance(_data, dict):
- items = _data.items()
- elif isinstance(_data, list):
- items = enumerate(_data)
- else:
- items = None
- if items is not None:
- stack.extend([
- (child, _data, key)
- for key, child in items
- ])
diff --git a/cylc/uiserver/websockets/tornado.py b/cylc/uiserver/websockets/tornado.py
deleted file mode 100644
index aaf2e285..00000000
--- a/cylc/uiserver/websockets/tornado.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# This file is a temporary solution for subscriptions with graphql_ws and
-# Tornado, from the following pending PR to graphql-ws:
-# https://github.com/graphql-python/graphql-ws/pull/25/files
-# The file was copied from this revision:
-# https://github.com/graphql-python/graphql-ws/blob/cf560b9a5d18d4a3908dc2cfe2199766cc988fef/graphql_ws/tornado.py
-
-from contextlib import suppress
-import getpass
-from inspect import isawaitable, isclass
-import socket
-
-from asyncio import create_task, gather, wait, shield, sleep
-from asyncio.queues import QueueEmpty
-from tornado.websocket import WebSocketClosedError
-from graphql.execution.middleware import MiddlewareManager
-from graphql_ws.base import ConnectionClosedException, BaseSubscriptionServer
-from graphql_ws.base_async import (
- BaseAsyncConnectionContext,
- BaseAsyncSubscriptionServer
-)
-from graphql_ws.observable_aiter import setup_observable_extension
-from graphql_ws.constants import (
- GQL_CONNECTION_ACK,
- GQL_CONNECTION_ERROR,
- GQL_COMPLETE
-)
-
-from typing import Union, Awaitable, Any, List, Tuple, Dict, Optional
-
-from cylc.uiserver.authorise import AuthorizationMiddleware
-from cylc.uiserver.websockets.resolve import resolve
-
-
-setup_observable_extension()
-
-NO_MSG_DELAY = 1.0
-
-
-class TornadoConnectionContext(BaseAsyncConnectionContext):
- async def receive(self):
- try:
- return self.ws.recv_nowait()
- except WebSocketClosedError:
- raise ConnectionClosedException()
-
- async def send(self, data):
- if self.closed:
- return
- await self.ws.write_message(data)
-
- @property
- def closed(self):
- return self.ws.close_code is not None
-
- async def close(self, code):
- await self.ws.close(code)
-
-
-class TornadoSubscriptionServer(BaseAsyncSubscriptionServer):
- def __init__(
- self, schema,
- keep_alive=True,
- loop=None,
- backend=None,
- middleware=None,
- auth=None
- ):
- self.loop = loop
- self.backend = backend or None
- self.middleware = middleware
- self.auth = auth
- super().__init__(schema, keep_alive)
-
- @staticmethod
- def instantiate_middleware(middlewares):
- for middleware in middlewares:
- if isclass(middleware):
- yield middleware()
- continue
- yield middleware
-
- def get_graphql_params(self, *args, **kwargs):
- params = super(TornadoSubscriptionServer,
- self).get_graphql_params(*args, **kwargs)
- # If middleware get instantiated here (optional), they will
- # be local/private to each subscription.
- if self.middleware is not None:
- middleware = list(
- self.instantiate_middleware(self.middleware)
- )
- else:
- middleware = self.middleware
- for mw in self.middleware:
- if mw == AuthorizationMiddleware:
- mw.auth = self.auth
- return dict(
- params,
- return_promise=True,
- backend=self.backend,
- middleware=MiddlewareManager(
- *middleware,
- wrap_in_promise=False
- ),
- )
-
- async def _handle(self, ws, request_context=None):
- connection_context = TornadoConnectionContext(ws, request_context)
- await self.on_open(connection_context)
- while True:
- message = None
- try:
- if connection_context.closed:
- raise ConnectionClosedException()
- message = await connection_context.receive()
- except QueueEmpty:
- pass
- except ConnectionClosedException:
- break
- if message:
- self.on_message(connection_context, message)
- else:
- await sleep(NO_MSG_DELAY)
-
- await self.on_close(connection_context)
-
- async def handle(self, ws, request_context=None):
- await shield(self._handle(ws, request_context))
-
- async def on_start(self, connection_context, op_id, params):
- # Attempt to unsubscribe first in case we already have a subscription
- # with this id.
- await connection_context.unsubscribe(op_id)
-
- params['root_value'] = op_id
- execution_result = self.execute(params)
- try:
- if isawaitable(execution_result):
- execution_result = await execution_result
- if not hasattr(execution_result, '__aiter__'):
- await self.send_execution_result(
- connection_context, op_id, execution_result)
- else:
- iterator = await execution_result.__aiter__()
- connection_context.register_operation(op_id, iterator)
- async for single_result in iterator:
- if not connection_context.has_operation(op_id):
- break
- await self.send_execution_result(
- connection_context, op_id, single_result)
- except Exception as e:
- await self.send_error(connection_context, op_id, e)
- await self.send_message(connection_context, op_id, GQL_COMPLETE)
- await connection_context.unsubscribe(op_id)
- await self.on_operation_complete(connection_context, op_id)
-
- async def on_operation_complete(self, connection_context, op_id):
- # remove the subscription from the sub_statuses dict
- with suppress(KeyError):
- connection_context.request_context['sub_statuses'].pop(op_id)
-
-
- async def send_execution_result(self, connection_context, op_id, execution_result):
- # Resolve any pending promises
- if execution_result.data and 'logs' not in execution_result.data:
- await resolve(execution_result.data)
- request_context = connection_context.request_context
- await request_context['resolvers'].flow_delta_processed(request_context, op_id)
- else:
- await resolve(execution_result.data)
-
- # NOTE: skip TornadoSubscriptionServer.send_execution_result because it
- # calls "resolve" then invokes BaseSubscriptionServer.send_execution_result
- await BaseSubscriptionServer.send_execution_result(
- self,
- connection_context,
- op_id,
- execution_result,
- )
diff --git a/mypy.ini b/mypy.ini
index 33ce48d4..9326a592 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -3,7 +3,7 @@ python_version = 3.8
ignore_missing_imports = True
files = cylc/uiserver
# don't run mypy on these files directly
-exclude = cylc/uiserver/(tests/|websockets/tornado.py|jupyter_config.py)
+exclude = cylc/uiserver/(tests/|jupyter_config.py)
# Enable PEP 420 style namespace packages, which we use.
# Needed for associating "import foo.bar" with foo/bar.py
diff --git a/setup.cfg b/setup.cfg
index 9006b989..77370054 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -47,13 +47,8 @@ install_requires =
# dependencies first. This way, if other dependencies (e.g. jupyterhub)
# don't pin versions, we will get whatever cylc-flow needs, and not the
# bleeding-edge version.
- # NB: no graphene version specified; we only make light use of it in our
- # own code, so graphene-tornado's transitive version should do.
cylc-flow==8.5.*
ansimarkup>=1.0.0
- graphene
- graphene-tornado==2.6.*
- graphql-ws==0.4.4
importlib-resources>=1.3.0; python_version < "3.9"
jupyter_server>=2.7
requests
@@ -64,11 +59,7 @@ install_requires =
# Transitive dependencies that we directly (lightly) use:
pyzmq
- graphql-core
-
- # Fix lack of upper pin for rx in graphql-core<2.3.0 (remove when
- # upgrading graphene to 3.0):
- rx<2
+ graphene
[options.packages.find]
include = cylc*
diff --git a/tox.ini b/tox.ini
index 91d4b7fb..9542f66d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -8,7 +8,6 @@ ignore=
W503,
; line break after binary operator
W504
-
exclude=
build,
dist,
@@ -18,6 +17,5 @@ exclude=
.tox,
.eggs,
cylc/uiserver/jupyter_config.py,
- cylc/uiserver/websockets/tornado.py
paths =
./cylc/uiserver