-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
139 graphql strawberry #141
Draft
paula-mg
wants to merge
3
commits into
main
Choose a base branch
from
139-graphql-strawberry
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
from specs import Line | ||
|
||
from scanspec.core import ( | ||
Frames, | ||
Path, | ||
) | ||
|
||
|
||
def reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path: | ||
"""Removes frames from a spec so len(path) < max_frames. | ||
|
||
Args: | ||
stack: A stack of Frames created by a spec | ||
max_frames: The maximum number of frames the user wishes to be returned | ||
""" | ||
# Calculate the total number of frames | ||
num_frames = 1 | ||
for frames in stack: | ||
num_frames *= len(frames) | ||
|
||
# Need each dim to be this much smaller | ||
ratio = 1 / np.power(max_frames / num_frames, 1 / len(stack)) | ||
|
||
sub_frames = [sub_sample(f, ratio) for f in stack] | ||
return Path(sub_frames) | ||
|
||
|
||
def sub_sample(frames: Frames[str], ratio: float) -> Frames: | ||
"""Provides a sub-sample Frames object whilst preserving its core structure. | ||
|
||
Args: | ||
frames: the Frames object to be reduced | ||
ratio: the reduction ratio of the dimension | ||
""" | ||
num_indexes = int(len(frames) / ratio) | ||
indexes = np.linspace(0, len(frames) - 1, num_indexes, dtype=np.int32) | ||
return frames.extract(indexes, calculate_gap=False) | ||
|
||
|
||
def validate_spec(spec: Line) -> Any: | ||
"""A query used to confirm whether or not the Spec will produce a viable scan.""" | ||
# TODO apischema will do all the validation for us | ||
return spec.serialize() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import Any | ||
|
||
import strawberry | ||
from fastapi import FastAPI | ||
from resolvers import reduce_frames, validate_spec | ||
from specs import PointsResponse | ||
from strawberry.fastapi import GraphQLRouter | ||
|
||
from scanspec.core import Path | ||
from scanspec.specs import Line, Spec | ||
|
||
# Here is the manual version of what we are trying to do | ||
|
||
# @strawberry.input | ||
# class LineInput(Line): ... | ||
|
||
|
||
# @strawberry.input | ||
# class ZipInput(Zip): ... | ||
|
||
|
||
# @strawberry.input(one_of=True) | ||
# class SpecInput: | ||
# ... | ||
|
||
# line: LineInput | None = strawberry.UNSET | ||
# zip: ZipInput | None = strawberry.UNSET | ||
|
||
|
||
def generate_input_class() -> type[Any]: | ||
# This will be our input class, we're going to fiddle with it | ||
# throughout this function | ||
class SpecInput: ... | ||
|
||
# We want to go through all the possible scan specs, this isn't | ||
# currently possible but can be implemented. | ||
# Raise an issue for a helper function to get all possible scanspec | ||
# types. | ||
for spec_type in Spec.types: | ||
# We make a strawberry input classs using the scanspec pydantic models | ||
# This isn't possible because scanspec models are actually pydantic | ||
# dataclasses. We should have a word with Tom about it and probably | ||
# raise an issue on strawberry. | ||
@strawberry.experimental.pydantic.input(all_fields=True, model=spec_type) | ||
class InputClass: ... | ||
|
||
# Renaming the class to LineInput, ZipInput etc. so the | ||
# schema looks neater | ||
InputClass.__name__ = spec_type.__name__ + "Input" | ||
|
||
# Add a field to the class called line, zip etc. and make it | ||
# strawberry.UNSET | ||
setattr(SpecInput, spec_type.__name__, strawberry.UNSET) | ||
|
||
# Set the type annotation to line | none, zip | none, etc. | ||
# Strawberry will read this and graphqlify it. | ||
SpecInput.__annotations__[spec_type.__name__] = InputClass | None | ||
|
||
# This is just a programtic equivalent of | ||
# @strawberry.input(one_of=True) | ||
# class SpecInput: | ||
# ... | ||
return strawberry.input(one_of=True)(SpecInput) | ||
|
||
|
||
SpecInput = generate_input_class() | ||
|
||
|
||
@strawberry.type | ||
class Query: | ||
@strawberry.field | ||
def validate(self, spec: SpecInput) -> str: | ||
return validate_spec(spec) | ||
|
||
@strawberry.field | ||
def get_points(self, spec: Line, max_frames: int | None = 10000) -> PointsResponse: | ||
"""Calculate the frames present in the scan plus some metadata about the points. | ||
|
||
Args: | ||
spec: The specification of the scan | ||
max_frames: The maximum number of frames the user wishes to receive | ||
""" | ||
|
||
dims = spec.calculate() # Grab dimensions from spec | ||
|
||
path = Path(dims) # Convert to a path | ||
|
||
# TOTAL FRAMES | ||
total_frames = len(path) # Capture the total length of the path | ||
|
||
# MAX FRAMES | ||
# Limit the consumed data by the max_frames argument | ||
if max_frames and (max_frames < len(path)): | ||
# Cap the frames by the max limit | ||
path = reduce_frames(dims, max_frames) | ||
# WARNING: path object is consumed after this statement | ||
chunk = path.consume(max_frames) | ||
|
||
return PointsResponse(chunk, total_frames) | ||
|
||
|
||
schema = strawberry.Schema(Query) | ||
|
||
graphql_app = GraphQLRouter(schema, path="/", graphql_ide="apollo-sandbox") | ||
|
||
app = FastAPI() | ||
app.include_router(graphql_app, prefix="/graphql") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from _collections_abc import Callable, Mapping | ||
from typing import Any | ||
|
||
import numpy as np | ||
import strawberry | ||
|
||
from scanspec.core import ( | ||
Axis, | ||
Frames, | ||
gap_between_frames, | ||
) | ||
|
||
|
||
@strawberry.type | ||
class PointsResponse: | ||
"""Information about the points provided by a spec.""" | ||
|
||
total_frames: int | ||
returned_frames: int | ||
|
||
def __init__(self, chunk: Frames[str], total_frames: int): | ||
self.total_frames = total_frames | ||
"""The number of frames present across the entire spec""" | ||
self.returned_frames = len(chunk) | ||
"""The number of frames returned by the getPoints query | ||
(controlled by the max_points argument)""" | ||
self._chunk = chunk | ||
|
||
|
||
@strawberry.interface | ||
class SpecInterface: | ||
def serialize(self) -> Mapping[str, Any]: | ||
"""Serialize the spec to a dictionary.""" | ||
return "serialized" | ||
|
||
|
||
def _dimensions_from_indexes( | ||
func: Callable[[np.ndarray], dict[Axis, np.ndarray]], | ||
axes: list, | ||
num: int, | ||
bounds: bool, | ||
) -> list[Frames[Axis]]: | ||
# Calc num midpoints (fences) from 0.5 .. num - 0.5 | ||
midpoints_calc = func(np.linspace(0.5, num - 0.5, num)) | ||
midpoints = {a: midpoints_calc[a] for a in axes} | ||
if bounds: | ||
# Calc num + 1 bounds (posts) from 0 .. num | ||
bounds_calc = func(np.linspace(0, num, num + 1)) | ||
lower = {a: bounds_calc[a][:-1] for a in axes} | ||
upper = {a: bounds_calc[a][1:] for a in axes} | ||
# Points must have no gap as upper[a][i] == lower[a][i+1] | ||
# because we initialized it to be that way | ||
gap = np.zeros(num, dtype=np.bool_) | ||
dimension = Frames(midpoints, lower, upper, gap) | ||
# But calc the first point as difference between first | ||
# and last | ||
gap[0] = gap_between_frames(dimension, dimension) | ||
else: | ||
# Gap can be calculated in Dimension | ||
dimension = Frames(midpoints) | ||
return [dimension] | ||
|
||
|
||
@strawberry.input | ||
class Line(SpecInterface): | ||
axis: str = strawberry.field(description="An identifier for what to move") | ||
start: float = strawberry.field( | ||
description="Midpoint of the first point of the line" | ||
) | ||
stop: float = strawberry.field(description="Midpoint of the last point of the line") | ||
num: int = strawberry.field(description="Number of frames to produce") | ||
|
||
def axes(self) -> list: | ||
return [self.axis] | ||
|
||
def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: | ||
if self.num == 1: | ||
# Only one point, stop-start gives length of one point | ||
step = self.stop - self.start | ||
else: | ||
# Multiple points, stop-start gives length of num-1 points | ||
step = (self.stop - self.start) / (self.num - 1) | ||
# self.start is the first centre point, but we need the lower bound | ||
# of the first point as this is where the index array starts | ||
first = self.start - step / 2 | ||
return {self.axis: indexes * step + first} | ||
|
||
def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: | ||
return _dimensions_from_indexes( | ||
self._line_from_indexes, self.axes(), self.num, bounds | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@callumforrester I am not tied to dataclass, but I have a requirement for positional args. I have a choice of directions for you:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To elaborate on 2, this works for both static analysis and at runtime:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have no particular preference, we should discuss with @paula-mg since she'll be doing the work.