Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

139 graphql strawberry #141

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ requires-python = ">=3.10"
plotting = ["scipy", "matplotlib"]
# REST service support
service = ["fastapi>=0.100.0", "uvicorn"]
# For development tests/docs
dev = [
# This syntax is supported since pip 21.2
# https://github.com/pypa/pip/issues/10393
Expand All @@ -41,11 +40,13 @@ dev = [
"sphinx-copybutton",
"sphinx-design",
"sphinxcontrib-openapi",
"strawberry-graphql[debug-server]",
"strawberry-graphql[fastapi]",
"tox-direct",
"types-mock",
"httpx",
"myst-parser",
]
] # For development tests/docs

[project.scripts]
scanspec = "scanspec.cli:cli"
Expand Down
Empty file added src/scanspec/schema/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions src/scanspec/schema/resolvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Any

import numpy as np
from specs import Line

from scanspec.core import (
Frames,
Path,
)


def reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path:
"""Removes frames from a spec so len(path) < max_frames.

Args:
stack: A stack of Frames created by a spec
max_frames: The maximum number of frames the user wishes to be returned
"""
# Calculate the total number of frames
num_frames = 1
for frames in stack:
num_frames *= len(frames)

# Need each dim to be this much smaller
ratio = 1 / np.power(max_frames / num_frames, 1 / len(stack))

sub_frames = [sub_sample(f, ratio) for f in stack]
return Path(sub_frames)


def sub_sample(frames: Frames[str], ratio: float) -> Frames:
"""Provides a sub-sample Frames object whilst preserving its core structure.

Args:
frames: the Frames object to be reduced
ratio: the reduction ratio of the dimension
"""
num_indexes = int(len(frames) / ratio)
indexes = np.linspace(0, len(frames) - 1, num_indexes, dtype=np.int32)
return frames.extract(indexes, calculate_gap=False)


def validate_spec(spec: Line) -> Any:
"""A query used to confirm whether or not the Spec will produce a viable scan."""
# TODO apischema will do all the validation for us
return spec.serialize()
107 changes: 107 additions & 0 deletions src/scanspec/schema/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Any

import strawberry
from fastapi import FastAPI
from resolvers import reduce_frames, validate_spec
from specs import PointsResponse
from strawberry.fastapi import GraphQLRouter

from scanspec.core import Path
from scanspec.specs import Line, Spec

# Here is the manual version of what we are trying to do

# @strawberry.input
# class LineInput(Line): ...


# @strawberry.input
# class ZipInput(Zip): ...


# @strawberry.input(one_of=True)
# class SpecInput:
# ...

# line: LineInput | None = strawberry.UNSET
# zip: ZipInput | None = strawberry.UNSET


def generate_input_class() -> type[Any]:
# This will be our input class, we're going to fiddle with it
# throughout this function
class SpecInput: ...

# We want to go through all the possible scan specs, this isn't
# currently possible but can be implemented.
# Raise an issue for a helper function to get all possible scanspec
# types.
for spec_type in Spec.types:
# We make a strawberry input classs using the scanspec pydantic models
# This isn't possible because scanspec models are actually pydantic
# dataclasses. We should have a word with Tom about it and probably
# raise an issue on strawberry.
Comment on lines +40 to +43
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@callumforrester I am not tied to dataclass, but I have a requirement for positional args. I have a choice of directions for you:

  1. Continue using dataclasses, add support to strawberry, probably using something like what I needed to do to autodoc_pydantic
  2. Ditch dataclasses and make a BaseModel subclass with positional arg support. I think we could make this work both at runtime using something like this and at static analysis by overriding the dataclass_transform. The closed issue would suggest that pydantic would never accept such an approach upstream.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To elaborate on 2, this works for both static analysis and at runtime:

from __future__ import annotations

from abc import ABCMeta
from typing import Any

from pydantic import BaseModel, Field
from typing_extensions import dataclass_transform

# TODO: not sure about the others like NoInitField and PrivateAttr
@dataclass_transform(field_specifiers=(Field,))
class PosargsMeta(type(BaseModel), ABCMeta):
    def __new__(
        mcs,
        cls_name: str,
        bases: tuple[type[Any], ...],
        namespace: dict[str, Any],
        **kwargs: Any,
    ) -> type:
        cls = super().__new__(mcs, cls_name, bases, namespace, **kwargs)
        original_init = cls.__init__

        def patched_init(self, *args, **kwargs):
            for k, v in zip(cls.model_fields, args, strict=False):
                kwargs[k] = v
            original_init(self, **kwargs)

        cls.__init__ = patched_init
        return cls

class Spec(BaseModel, metaclass=PosargsMeta):
    pass

class Line(Spec):
    start: float = Field(description="Midpoint of the first point of the line")
    stop: float = Field(description="Midpoint of the last point of the line")
    num: int = Field(min=1, description="Number of frames to produce")

# pyright and pydantic are happy with this...
obj = Line(3, 4, 5)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no particular preference, we should discuss with @paula-mg since she'll be doing the work.

@strawberry.experimental.pydantic.input(all_fields=True, model=spec_type)
class InputClass: ...

# Renaming the class to LineInput, ZipInput etc. so the
# schema looks neater
InputClass.__name__ = spec_type.__name__ + "Input"

# Add a field to the class called line, zip etc. and make it
# strawberry.UNSET
setattr(SpecInput, spec_type.__name__, strawberry.UNSET)

# Set the type annotation to line | none, zip | none, etc.
# Strawberry will read this and graphqlify it.
SpecInput.__annotations__[spec_type.__name__] = InputClass | None

# This is just a programtic equivalent of
# @strawberry.input(one_of=True)
# class SpecInput:
# ...
return strawberry.input(one_of=True)(SpecInput)


SpecInput = generate_input_class()


@strawberry.type
class Query:
@strawberry.field
def validate(self, spec: SpecInput) -> str:
return validate_spec(spec)

@strawberry.field
def get_points(self, spec: Line, max_frames: int | None = 10000) -> PointsResponse:
"""Calculate the frames present in the scan plus some metadata about the points.

Args:
spec: The specification of the scan
max_frames: The maximum number of frames the user wishes to receive
"""

dims = spec.calculate() # Grab dimensions from spec

path = Path(dims) # Convert to a path

# TOTAL FRAMES
total_frames = len(path) # Capture the total length of the path

# MAX FRAMES
# Limit the consumed data by the max_frames argument
if max_frames and (max_frames < len(path)):
# Cap the frames by the max limit
path = reduce_frames(dims, max_frames)
# WARNING: path object is consumed after this statement
chunk = path.consume(max_frames)

return PointsResponse(chunk, total_frames)


schema = strawberry.Schema(Query)

graphql_app = GraphQLRouter(schema, path="/", graphql_ide="apollo-sandbox")

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")
93 changes: 93 additions & 0 deletions src/scanspec/schema/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from _collections_abc import Callable, Mapping
from typing import Any

import numpy as np
import strawberry

from scanspec.core import (
Axis,
Frames,
gap_between_frames,
)


@strawberry.type
class PointsResponse:
"""Information about the points provided by a spec."""

total_frames: int
returned_frames: int

def __init__(self, chunk: Frames[str], total_frames: int):
self.total_frames = total_frames
"""The number of frames present across the entire spec"""
self.returned_frames = len(chunk)
"""The number of frames returned by the getPoints query
(controlled by the max_points argument)"""
self._chunk = chunk


@strawberry.interface
class SpecInterface:
def serialize(self) -> Mapping[str, Any]:
"""Serialize the spec to a dictionary."""
return "serialized"


def _dimensions_from_indexes(
func: Callable[[np.ndarray], dict[Axis, np.ndarray]],
axes: list,
num: int,
bounds: bool,
) -> list[Frames[Axis]]:
# Calc num midpoints (fences) from 0.5 .. num - 0.5
midpoints_calc = func(np.linspace(0.5, num - 0.5, num))
midpoints = {a: midpoints_calc[a] for a in axes}
if bounds:
# Calc num + 1 bounds (posts) from 0 .. num
bounds_calc = func(np.linspace(0, num, num + 1))
lower = {a: bounds_calc[a][:-1] for a in axes}
upper = {a: bounds_calc[a][1:] for a in axes}
# Points must have no gap as upper[a][i] == lower[a][i+1]
# because we initialized it to be that way
gap = np.zeros(num, dtype=np.bool_)
dimension = Frames(midpoints, lower, upper, gap)
# But calc the first point as difference between first
# and last
gap[0] = gap_between_frames(dimension, dimension)
else:
# Gap can be calculated in Dimension
dimension = Frames(midpoints)
return [dimension]


@strawberry.input
class Line(SpecInterface):
axis: str = strawberry.field(description="An identifier for what to move")
start: float = strawberry.field(
description="Midpoint of the first point of the line"
)
stop: float = strawberry.field(description="Midpoint of the last point of the line")
num: int = strawberry.field(description="Number of frames to produce")

def axes(self) -> list:
return [self.axis]

def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]:
if self.num == 1:
# Only one point, stop-start gives length of one point
step = self.stop - self.start
else:
# Multiple points, stop-start gives length of num-1 points
step = (self.stop - self.start) / (self.num - 1)
# self.start is the first centre point, but we need the lower bound
# of the first point as this is where the index array starts
first = self.start - step / 2
return {self.axis: indexes * step + first}

def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]:
return _dimensions_from_indexes(
self._line_from_indexes, self.axes(), self.num, bounds
)
Loading