From 5c04ef9f1b93130ade714437195470cef51b355a Mon Sep 17 00:00:00 2001 From: Nicholas Bollweg Date: Mon, 7 Jan 2019 19:29:58 -0500 Subject: [PATCH 1/6] get deps up-to-date, add graphql-ws --- .gitignore | 1 - anaconda-project.yml | 40 ++++++++++++++++++++++++++++++++++++++++ environment.yml | 12 ++++-------- postBuild | 0 setup.cfg | 8 ++++++-- 5 files changed, 50 insertions(+), 11 deletions(-) create mode 100644 anaconda-project.yml mode change 100644 => 100755 postBuild diff --git a/.gitignore b/.gitignore index 5699ba7..d601e6f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,3 @@ envs/ lib/ Untitled*.ipynb static/ -anaconda-project.yml diff --git a/anaconda-project.yml b/anaconda-project.yml new file mode 100644 index 0000000..36a1cf7 --- /dev/null +++ b/anaconda-project.yml @@ -0,0 +1,40 @@ +name: jupyter-graphql-dev + +commands: + lab: + unix: jupyter lab --no-browser --debug + setup: + unix: pip install -e . --ignore-installed --no-deps + black: + unix: black src/py setup.py + atom: + unix: atom . + static: + unix: python -m jupyter_graphql.fetch_static + +env_specs: + default: + platforms: + - linux-64 + - osx-64 + - win-64 + inherit_from: + - jupyter-graphql-dev + packages: + - black + - flake8 + - beautysh + jupyter-graphql-dev: + packages: + - gql + - graphene + - iso8601 + - jupyterlab >=0.35,<0.36 + - python >=3.6,<3.7 + - requests + - werkzeug + - pip: + - graphql-ws + channels: + - conda-forge + - defaults diff --git a/environment.yml b/environment.yml index 08794bd..08933da 100644 --- a/environment.yml +++ b/environment.yml @@ -5,16 +5,12 @@ channels: - defaults dependencies: - - aniso8601 + - gql + - graphene + - iso8601 - jupyterlab >=0.35,<0.36 - - pip - - promise - python >=3.6,<3.7 - requests - - rx - werkzeug - pip: - - gql - - graphene - - graphql-core - - graphql-relay + - graphql-ws diff --git a/postBuild b/postBuild old mode 100644 new mode 100755 diff --git a/setup.cfg b/setup.cfg index 73a555f..458abe0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,8 +25,12 @@ classifiers = [options] install_requires = - jupyterlab - graphene-tornado + graphene + graphql-ws + iso8601 + notebook + werkzeug + package_dir = = src/py packages = find: From 4b56716e39a696907720b8dc6399e47473be18b9 Mon Sep 17 00:00:00 2001 From: Nicholas Bollweg Date: Mon, 7 Jan 2019 21:07:05 -0500 Subject: [PATCH 2/6] install subscriptions --- .gitignore | 1 - anaconda-project.yml | 1 + environment.yml | 1 + src/py/jupyter_graphql/__init__.py | 16 +- src/py/jupyter_graphql/executor.py | 33 -- src/py/jupyter_graphql/fetch_static.py | 30 +- src/py/jupyter_graphql/handlers.py | 368 ++---------------- src/py/jupyter_graphql/subscriptions.py | 116 ++++++ .../jupyter_graphql/templates/graphiql.html | 114 +++--- 9 files changed, 239 insertions(+), 441 deletions(-) delete mode 100644 src/py/jupyter_graphql/executor.py create mode 100644 src/py/jupyter_graphql/subscriptions.py diff --git a/.gitignore b/.gitignore index d601e6f..38257ba 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,6 @@ _scripts/ *_files/ *.bundle.* *.egg-info/ -*.html *.log *.tar.gz envs/ diff --git a/anaconda-project.yml b/anaconda-project.yml index 36a1cf7..7b73695 100644 --- a/anaconda-project.yml +++ b/anaconda-project.yml @@ -35,6 +35,7 @@ env_specs: - werkzeug - pip: - graphql-ws + - graphene-tornado channels: - conda-forge - defaults diff --git a/environment.yml b/environment.yml index 08933da..9731087 100644 --- a/environment.yml +++ b/environment.yml @@ -14,3 +14,4 @@ dependencies: - werkzeug - pip: - graphql-ws + - graphene-tornado diff --git a/src/py/jupyter_graphql/__init__.py b/src/py/jupyter_graphql/__init__.py index 0de1394..2b7eeb2 100644 --- a/src/py/jupyter_graphql/__init__.py +++ b/src/py/jupyter_graphql/__init__.py @@ -1,11 +1,10 @@ from pathlib import Path from notebook.base.handlers import FileFindHandler - from notebook.utils import url_path_join as ujoin -from .handlers import GraphQLHandler - +from .subscriptions import TornadoSubscriptionServer +from .handlers import GraphQLHandler, SubscriptionHandler from .schema import schema @@ -18,6 +17,8 @@ def load_jupyter_server_extension(app): app.log.info("[graphql] initializing") web_app = app.web_app + subscription_server = TornadoSubscriptionServer(schema) + # add our templates web_app.settings["jinja2_env"].loader.searchpath += [TEMPLATES] @@ -27,7 +28,6 @@ def base(*bits): def app_middleware(next, root, info, **args): setattr(info.context, "_app", app) return next(root, info, **args) - web_app.add_handlers( ".*$", [ @@ -37,10 +37,16 @@ def app_middleware(next, root, info, **args): dict( schema=schema, graphiql=True, - nb_app=app, middleware=[app_middleware], ), ), + ( + base("subscriptions"), + SubscriptionHandler, + dict( + subscription_server=subscription_server + ), + ), # serve the graphiql assets (base("static", "(.*)"), FileFindHandler, dict(path=[STATIC])), ], diff --git a/src/py/jupyter_graphql/executor.py b/src/py/jupyter_graphql/executor.py deleted file mode 100644 index 4263bc8..0000000 --- a/src/py/jupyter_graphql/executor.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Copied from: -https://github.com/dronedeploy/graphene-tornado/blob/master/graphene_tornado/tornado_executor.py -""" -from promise import Promise -from tornado.gen import convert_yielded, multi_future -from tornado.ioloop import IOLoop -from tornado.concurrent import is_future - - -# https://gist.github.com/isi-gach/daef0b34ec5af6f026af52d593131c64 -class TornadoExecutor(object): - def __init__(self, io_loop=None): - if io_loop is None: - io_loop = IOLoop.current() - self.loop = io_loop - self.futures = [] - - def wait_until_finished(self): - # if there are futures to wait for - while self.futures: - # wait for the futures to finish - futures = self.futures - self.futures = [] - self.loop.run_sync(lambda: multi_future(futures)) - - def execute(self, fn, *args, **kwargs): - result = fn(*args, **kwargs) - if is_future(result): - future = convert_yielded(result) - self.futures.append(future) - return Promise.resolve(future) - return result diff --git a/src/py/jupyter_graphql/fetch_static.py b/src/py/jupyter_graphql/fetch_static.py index 9600156..c2bcb37 100644 --- a/src/py/jupyter_graphql/fetch_static.py +++ b/src/py/jupyter_graphql/fetch_static.py @@ -1,28 +1,44 @@ from urllib.request import urlretrieve from urllib.parse import urlparse +import sys from . import STATIC # Download the file from `url` and save it locally under `file_name`: -ASSETS = [ +JSDELIVR_ASSETS = [ "https://cdn.jsdelivr.net/npm/graphiql@0.11.10/graphiql.css", "https://cdn.jsdelivr.net/npm/whatwg-fetch@2.0.3/fetch.min.js", "https://cdn.jsdelivr.net/npm/react@16.2.0/umd/react.production.min.js", "https://cdn.jsdelivr.net/npm/react-dom@16.2.0/umd/react-dom.production.min.js", "https://cdn.jsdelivr.net/npm/graphiql@0.11.10/graphiql.min.js", + # "https://cdn.jsdelivr.net/npm/subscriptions-transport-ws@0.7.0/browser/client.js", + # "https://cdn.jsdelivr.net/npm/graphiql-subscriptions-fetcher@0.0.2/dist/fetcher.js", +] + +UNPKG_ASSETS = [ + "https://unpkg.com/subscriptions-transport-ws@0.7.0/browser/client.js", + "https://unpkg.com/graphiql-subscriptions-fetcher@0.0.2/browser/client.js" ] -def fetch_static(): - for url in ASSETS: - out = (STATIC / urlparse(url).path[1:]).resolve() - if not out.exists(): +def fetch_assets(assets, prefix=None, force=False): + for url in assets: + out = STATIC + if prefix: + out = STATIC / prefix + out = (out / urlparse(url).path[1:]).resolve() + if force or not out.exists(): out.parent.mkdir(parents=True, exist_ok=True) out.write_text("") - print("fetching", url, "to", out) + print(f"fetching\n\t- {url}\n\t> {out.relative_to(STATIC)}") urlretrieve(url, out) +def fetch_static(force=False): + fetch_assets(JSDELIVR_ASSETS, force=force) + fetch_assets(UNPKG_ASSETS, "npm", force=force) + + if __name__ == "__main__": - fetch_static() + fetch_static(force="--force" in sys.argv) diff --git a/src/py/jupyter_graphql/handlers.py b/src/py/jupyter_graphql/handlers.py index 7b1cbeb..cb52e7b 100644 --- a/src/py/jupyter_graphql/handlers.py +++ b/src/py/jupyter_graphql/handlers.py @@ -1,361 +1,43 @@ -""" -Forked from: -https://github.com/dronedeploy/graphene-tornado/blob/master/graphene_tornado/tornado_graphql_handler.py -""" -from __future__ import absolute_import, division, print_function - -import inspect -import sys -import traceback -import json - -from tornado import web -from tornado.escape import json_encode, json_decode -from tornado.gen import coroutine, Return -from tornado.locks import Event -from tornado.log import app_log -from tornado.web import HTTPError - -from werkzeug.datastructures import MIMEAccept -from werkzeug.http import parse_accept_header - -from graphql import parse, validate, Source, get_operation_ast, execute -from graphql.error import GraphQLError -from graphql.error import format_error as format_graphql_error -from graphql.execution import ExecutionResult +from asyncio import Queue from notebook.base.handlers import IPythonHandler -from .executor import TornadoExecutor +from tornado import websocket, ioloop +from graphene_tornado.tornado_graphql_handler import TornadoGraphQLHandler +from graphene_tornado import render_graphiql +from graphql_ws.constants import GRAPHQL_WS -class ExecutionError(Exception): - def __init__(self, status_code=400, errors=None): - self.status_code = status_code - if errors is None: - self.errors = [] - else: - self.errors = [str(e) for e in errors] - self.message = "\n".join(self.errors) +from pathlib import Path +here = Path(__file__).parent +render_graphiql.TEMPLATE = (here / "templates" / "graphiql.html").read_text() -class GraphQLHandler(IPythonHandler): - executor = None - schema = None - batch = False - middleware = [] - pretty = False - root_value = None - graphiql = False - graphiql_version = None - graphiql_template = None - graphiql_html_title = None - - def initialize( - self, - schema=None, - executor=None, - middleware=None, - root_value=None, - graphiql=False, - pretty=False, - batch=False, - nb_app=None, - ): - super(GraphQLHandler, self).initialize() - self.schema = schema - if middleware is not None: - self.middleware = list(self.instantiate_middleware(middleware)) - self.executor = executor - self.root_value = root_value - self.pretty = pretty - self.graphiql = graphiql - self.batch = batch - self.nb_app = nb_app +class GraphQLHandler(TornadoGraphQLHandler, IPythonHandler): def check_xsrf_cookie(self, *args, **kwargs): return True - @web.authenticated - @coroutine - def get(self): - try: - yield self.run("get") - except Exception as ex: - self.handle_error(ex) - - @web.authenticated - @coroutine - def post(self): - try: - yield self.run("post") - except Exception as ex: - self.log.error("[graphql] ERR: %s", ex) - self.handle_error(ex) - - @coroutine - def run(self, method): - show_graphiql = self.graphiql and self.should_display_graphiql() - data = self.parse_body() - - if self.batch: - responses = [] - for entry in data: - r = yield self.get_response(entry, method, entry) - responses.append(r) - result = "[{}]".format(",".join([response[0] for response in responses])) - print(data, responses) - status_code = max(responses, key=lambda response: response[1])[1] - else: - result, status_code = yield self.get_response(data, method, show_graphiql) - - if show_graphiql: - query, variables, operation_name, id = self.get_graphql_params( - self.request, data - ) - self.finish( - self.render_template( - "graphiql.html", base_url=self.application.settings["base_url"] - ) - ) - return - - self.set_status(status_code) - self.set_header("Content-Type", "application/json") - self.write(result) - self.finish() - - def parse_body(self): - content_type = self.content_type - - if content_type == "application/graphql": - return {"query": self.request.body} - - elif content_type == "application/json": - # noinspection PyBroadException - try: - body = self.request.body - except Exception as e: - raise ExecutionError(400, e) - - try: - request_json = json_decode(body) - if self.batch: - assert isinstance(request_json, list), ( - "Batch requests should receive a list, but received {}." - ).format(repr(request_json)) - assert ( - len(request_json) > 0 - ), "Received an empty list in the batch request." - else: - assert isinstance( - request_json, dict - ), "The received data is not a valid JSON query." - return request_json - except AssertionError as e: - raise HTTPError(status_code=400, log_message=str(e)) - except (TypeError, ValueError): - raise HTTPError( - status_code=400, log_message="POST body sent invalid JSON." - ) - - elif content_type in [ - "application/x-www-form-urlencoded", - "multipart/form-data", - ]: - return self.request.query_arguments - - return {} - - @coroutine - def get_response(self, data, method, show_graphiql=False): - query, variables, operation_name, id = self.get_graphql_params( - self.request, data + def render_graphiql(self, query, variables, operation_name, result): + return self.render_template( + "graphiql.html", base_url=self.application.settings["base_url"] ) - execution_result = yield self.execute_graphql_request( - method, query, variables, operation_name, show_graphiql - ) - - status_code = 200 - - if execution_result: - response = {} - - if getattr(execution_result, "is_pending", False): - event = Event() - - def on_resolve(*args): - return event.set() - - execution_result.then(on_resolve).catch(on_resolve) - yield event.wait() - - if hasattr(execution_result, "get"): - execution_result = execution_result.get() - - if execution_result.errors: - response["errors"] = [ - self.format_error(e) for e in execution_result.errors - ] - - if execution_result.invalid: - status_code = 400 - else: - response["data"] = execution_result.data - - if self.batch: - response["id"] = id - response["status"] = status_code - - result = self.json_encode(response, pretty=self.pretty or show_graphiql) - else: - result = None - - raise Return((result, status_code)) - - @coroutine - def execute_graphql_request( - self, method, query, variables, operation_name, show_graphiql=False - ): - if not query: - if show_graphiql: - raise Return(None) - raise HTTPError(400, "Must provide query string.") - - source = Source(query, name="GraphQL request") - - try: - document_ast = parse(source) - validation_errors = validate(self.schema, document_ast) - except Exception as e: - raise Return(ExecutionResult(errors=[e], invalid=True)) - - if validation_errors: - raise Return(ExecutionResult(errors=validation_errors, invalid=True)) - - if method.lower() == "get": - operation_ast = get_operation_ast(document_ast, operation_name) - if operation_ast and operation_ast.operation != "query": - if show_graphiql: - raise Return(None) - - raise HTTPError( - 405, - "Can only perform a {} operation from a POST request.".format( - operation_ast.operation - ), - ) - - try: - result = yield self.execute( - document_ast, - root_value=self.root_value, - variable_values=variables, - operation_name=operation_name, - context_value=self.request, - middleware=self.middleware, - executor=self.executor or TornadoExecutor(), - return_promise=True, - ) - except Exception as e: - raise Return(ExecutionResult(errors=[e], invalid=True)) - - raise Return(result) - - @coroutine - def execute(self, *args, **kwargs): - raise Return(execute(self.schema, *args, **kwargs)) - - def json_encode(self, d, pretty=False): - if pretty or self.get_query_argument("pretty", False): - return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) - - return json.dumps(d, separators=(",", ":")) - - def should_display_graphiql(self): - raw = ( - "raw" in self.request.query_arguments.keys() - or "raw" in self.request.arguments - ) - return not raw and self.request_wants_html() - - def request_wants_html(self): - accept_header = self.request.headers.get("Accept", "") - accept_mimetypes = parse_accept_header(accept_header, MIMEAccept) - best = accept_mimetypes.best_match(["application/json", "text/html"]) - return ( - best == "text/html" - and accept_mimetypes[best] > accept_mimetypes["application/json"] - ) - - @property - def content_type(self): - return self.request.headers.get("Content-Type", "text/plain").split(";")[0] - - @staticmethod - def instantiate_middleware(middlewares): - for middleware in middlewares: - if inspect.isclass(middleware): - yield middleware() - continue - yield middleware - - @staticmethod - def get_graphql_params(request, data): - single_args = {} - for key in request.arguments.keys(): - single_args[key] = request.arguments.get(key)[0] - - query = single_args.get("query") or data.get("query") - variables = single_args.get("variables") or data.get("variables") - id = single_args.get("id") or data.get("id") - - if variables and isinstance(variables, str): - try: - variables = json_decode(variables) - except Exception: - raise HTTPError(400, "Variables are invalid JSON.") - - operation_name = single_args.get("operationName") or data.get("operationName") - if operation_name == "null": - operation_name = None - - return query, variables, operation_name, id - def handle_error(self, ex): - if not isinstance(ex, (web.HTTPError, ExecutionError, GraphQLError)): - tb = "".join(traceback.format_exception(*sys.exc_info())) - app_log.error("Error: {0} {1}".format(ex, tb)) - self.set_status(self.error_status(ex)) - error_json = json_encode({"errors": self.error_format(ex)}) - app_log.debug("error_json: %s", error_json) - self.write(error_json) +# TODO: use notebook websocket stuff +class SubscriptionHandler(websocket.WebSocketHandler): + def initialize(self, subscription_server): + self.subscription_server = subscription_server + self.queue = Queue(100) - @staticmethod - def error_status(exception): - app_log.error(exception) - if isinstance(exception, web.HTTPError): - return exception.status_code - elif isinstance(exception, (ExecutionError, GraphQLError)): - return 400 - else: - return 500 + def select_subprotocol(self, subprotocols): + return GRAPHQL_WS - @staticmethod - def error_format(exception): - if isinstance(exception, ExecutionError): - return [{"message": e} for e in exception.errors] - elif isinstance(exception, GraphQLError): - return [format_graphql_error(exception)] - elif isinstance(exception, web.HTTPError): - return [{"message": exception.log_message}] - else: - return [{"message": "Unknown server error"}] + def open(self): + ioloop.IOLoop.current().spawn_callback(self.subscription_server.handle, self) - @staticmethod - def format_error(error): - if isinstance(error, GraphQLError): - return format_graphql_error(error) + async def on_message(self, message): + await self.queue.put(message) - return {"message": str(error)} + async def recv(self): + return await self.queue.get() diff --git a/src/py/jupyter_graphql/subscriptions.py b/src/py/jupyter_graphql/subscriptions.py new file mode 100644 index 0000000..797acc9 --- /dev/null +++ b/src/py/jupyter_graphql/subscriptions.py @@ -0,0 +1,116 @@ +# Temporary vendoring from +# https://github.com/graphql-python/graphql-ws/blob/cf560b9a5d18d4a3908dc2cfe2199766cc988fef/graphql_ws/tornado.py +from inspect import isawaitable +from asyncio import ensure_future, wait, shield + +from tornado.websocket import WebSocketClosedError + +from graphql.execution.executors.asyncio import AsyncioExecutor + +from graphql_ws.base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer +from graphql_ws.observable_aiter import setup_observable_extension +from graphql_ws.constants import ( + GQL_CONNECTION_ACK, + GQL_CONNECTION_ERROR, + GQL_COMPLETE +) + +setup_observable_extension() + + +class TornadoConnectionContext(BaseConnectionContext): + async def receive(self): + try: + msg = await self.ws.recv() + return msg + except WebSocketClosedError: + raise ConnectionClosedException() + + async def send(self, data): + if self.closed: + return + await self.ws.write_message(data) + + @property + def closed(self): + return self.ws.close_code is not None + + async def close(self, code): + await self.ws.close(code) + + +class TornadoSubscriptionServer(BaseSubscriptionServer): + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) + + def get_graphql_params(self, *args, **kwargs): + params = super(TornadoSubscriptionServer, + self).get_graphql_params(*args, **kwargs) + return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop)) + + async def _handle(self, ws, request_context): + connection_context = TornadoConnectionContext(ws, request_context) + await self.on_open(connection_context) + pending = set() + while True: + try: + if connection_context.closed: + raise ConnectionClosedException() + message = await connection_context.receive() + except ConnectionClosedException: + break + finally: + if pending: + (_, pending) = await wait(pending, timeout=0, loop=self.loop) + + task = ensure_future( + self.on_message(connection_context, message), loop=self.loop) + pending.add(task) + + self.on_close(connection_context) + for task in pending: + task.cancel() + + async def handle(self, ws, request_context=None): + await shield(self._handle(ws, request_context), loop=self.loop) + + async def on_open(self, connection_context): + pass + + def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + execution_result = self.execute( + connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if not hasattr(execution_result, '__aiter__'): + await self.send_execution_result(connection_context, op_id, execution_result) + else: + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result(connection_context, op_id, single_result) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + + async def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id) diff --git a/src/py/jupyter_graphql/templates/graphiql.html b/src/py/jupyter_graphql/templates/graphiql.html index d02a6b0..105f9f7 100644 --- a/src/py/jupyter_graphql/templates/graphiql.html +++ b/src/py/jupyter_graphql/templates/graphiql.html @@ -8,6 +8,7 @@ + - - - - - - - - - - + + + + + + + + - + +