-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dimensionality information to Specs #89
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -11,6 +11,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
List, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Optional, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Sequence, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Tuple, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Type, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TypeVar, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Union, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -515,6 +516,30 @@ def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
return squashed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class DimensionInfo(Generic[Axis]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
axes: Tuple[Tuple[Axis, ...]], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
shape: Tuple[int, ...], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
snaked: Tuple[bool, ...] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._axes = axes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._shape = shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._snaked = snaked or (False,) * len(shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unsure if this would help, but is it easier to make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Snaked is False until an inner Dimension when it's True. Maybe it should be the outermost Dimension that is Snaked instead, or a dimension index? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def axes(self) -> Tuple[Tuple[Axis, ...]]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._axes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def shape(self) -> Tuple[int, ...]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def snaked(self) -> Tuple[bool, ...]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._snaked | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+519
to
+542
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Achieves the same thing and more inkeeping with the rest of the codebase
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also would prefer There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not in keeping with the other classes in this file, which is why I didn't make a dataclass. it may be a holdover from worse dataclass Generic handling? I'll see how it behaves and maybe make a PR to move Frame etc. over to be dataclasses: the specs are dataclasses but the Frames etc. are not. RE: Dimensions vs DimensionInfo: Dimension has meaning from SPG that is already present in some of the docs. This object is not a Dimension or collection of Dimensions, it describes some Dimensions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could also suggest There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dimensionality I like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point, none of the other classes in core are dataclasses, but they do all contain a lot of logic, so I think that's correct. This almost purely holds data. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class Path(Generic[Axis]): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""A consumable route through a stack of Frames, representing a scan path. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,17 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import asdict | ||
from functools import reduce | ||
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type | ||
from warnings import warn | ||
|
||
import numpy as np | ||
from pydantic import Field, parse_obj_as | ||
from pydantic.dataclasses import dataclass | ||
|
||
from .core import ( | ||
Axis, | ||
DimensionInfo, | ||
Frames, | ||
Midpoints, | ||
Path, | ||
|
@@ -60,7 +63,14 @@ def axes(self) -> List[Axis]: | |
|
||
Ordered from slowest moving to fastest moving. | ||
""" | ||
raise NotImplementedError(self) | ||
warn( | ||
"axes() is deprecated, call dimension_info()", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure this should be deprecated. I think there are still valid reasons to just say "I want a list of all axes that may be involved in this scan" without caring about/having to unpack the dimensions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AIUI Tom requested deprecating axes(), shape(), will have to check when he's back. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Chatted with Tom, happy to deprecate for now and and add a |
||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
return reduce( | ||
lambda a, b: a + list(b), self.dimension_info().axes, initial=list() | ||
) | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
"""Produce a stack of nested `Frames` that form the scan. | ||
|
@@ -69,6 +79,16 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | |
""" | ||
raise NotImplementedError(self) | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
"""Returns the list of axes in each dimension of the scan, | ||
paired with the information on how large each dimension of the scan is, | ||
and whether the dimension is snaked in the dimension outside it. | ||
|
||
Deprecates shape() as does not need to call calculate() | ||
Deprecates axes() as has the per-dimension information | ||
""" | ||
raise NotImplementedError(self) | ||
|
||
def frames(self) -> Frames[Axis]: | ||
"""Expand all the scan `Frames` and return them.""" | ||
return Path(self.calculate()).consume() | ||
|
@@ -79,7 +99,12 @@ def midpoints(self) -> Midpoints[Axis]: | |
|
||
def shape(self) -> Tuple[int, ...]: | ||
"""Return the final, simplified shape of the scan.""" | ||
return tuple(len(dim) for dim in self.calculate()) | ||
warn( | ||
"shape() is deprecated, call dimension_info()", | ||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
return self.dimension_info().shape | ||
|
||
def __rmul__(self, other) -> Product[Axis]: | ||
return if_instance_do(other, int, lambda o: Product(Repeat(o), self)) | ||
|
@@ -127,14 +152,20 @@ class Product(Spec[Axis]): | |
outer: Spec[Axis] = Field(description="Will be executed once") | ||
inner: Spec[Axis] = Field(description="Will be executed len(outer) times") | ||
|
||
def axes(self) -> List: | ||
return self.outer.axes() + self.inner.axes() | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
frames_outer = self.outer.calculate(bounds=False, nested=nested) | ||
frames_inner = self.inner.calculate(bounds, nested=True) | ||
return frames_outer + frames_inner | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
outer_info = self.outer.dimension_info() | ||
inner_info = self.inner.dimension_info() | ||
return DimensionInfo( | ||
axes=outer_info.axes + inner_info.axes, | ||
shape=outer_info.shape + inner_info.shape, | ||
snaked=outer_info.snaked + inner_info.snaked, | ||
) | ||
|
||
|
||
@dataclass(config=StrictConfig) | ||
class Repeat(Spec[Axis]): | ||
|
@@ -166,12 +197,12 @@ class Repeat(Spec[Axis]): | |
default=True, | ||
) | ||
|
||
def axes(self) -> List: | ||
return [] | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
return [Frames({}, gap=np.full(self.num, self.gap))] | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
return DimensionInfo(axes=((DURATION,),), shape=(self.num,)) | ||
|
||
|
||
@dataclass(config=StrictConfig) | ||
class Zip(Spec[Axis]): | ||
|
@@ -203,9 +234,6 @@ class Zip(Spec[Axis]): | |
description="The right-hand Spec to Zip, will appear later in axes" | ||
) | ||
|
||
def axes(self) -> List: | ||
return self.left.axes() + self.right.axes() | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
frames_left = self.left.calculate(bounds, nested) | ||
frames_right = self.right.calculate(bounds, nested) | ||
|
@@ -243,6 +271,24 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | |
frames.append(combined) | ||
return frames | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
left_info = self.left.dimension_info() | ||
right_info = self.right.dimension_info() | ||
left = left_info.axes | ||
padded_right = ((None,),) * ( | ||
len(left_info.axes) - len(right_info.axes) | ||
) + right_info.axes | ||
axes = tuple( | ||
left[i] if padded_right[i] == (None,) else left[i] + padded_right[i] | ||
for i in range(len(left_info.axes)) | ||
) | ||
|
||
return DimensionInfo( | ||
axes=axes, | ||
shape=left_info.shape, | ||
snaked=left_info.snaked, # Non-matching Snake axes cannot be Zipped | ||
) | ||
|
||
|
||
@dataclass(config=StrictConfig) | ||
class Mask(Spec[Axis]): | ||
|
@@ -271,9 +317,6 @@ class Mask(Spec[Axis]): | |
default=True, | ||
) | ||
|
||
def axes(self) -> List: | ||
return self.spec.axes() | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
frames = self.spec.calculate(bounds, nested) | ||
for axis_set in self.region.axis_sets(): | ||
|
@@ -295,6 +338,21 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | |
masked_frames.append(f.extract(indices)) | ||
return masked_frames | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
""" | ||
As Mask applies a Region to a Spec, which may alter the Spec, but this is not | ||
knowable without calculating the entire Spec, we have to calculate the Spec | ||
here. | ||
Currently we throw away the results of this calculation, but in future we may | ||
want to cache the result, or else modify the behaviour of this method generally | ||
to match. | ||
""" | ||
frames = self.calculate(bounds=False, nested=False) | ||
shape = tuple(len(x.midpoints) for x in frames) | ||
axes = tuple(tuple(x.axes()) for x in frames) | ||
callumforrester marked this conversation as resolved.
Show resolved
Hide resolved
|
||
snaked = tuple(isinstance(x, SnakedFrames) for x in frames) | ||
return DimensionInfo(axes=axes, shape=shape, snaked=snaked) | ||
|
||
# *+ bind more tightly than &|^ so without these overrides we | ||
# would need to add brackets to all combinations of Regions | ||
def __or__(self, other: Region[Axis]) -> Mask[Axis]: | ||
|
@@ -329,15 +387,21 @@ class Snake(Spec[Axis]): | |
description="The Spec to run in reverse every other iteration" | ||
) | ||
|
||
def axes(self) -> List: | ||
return self.spec.axes() | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
return [ | ||
SnakedFrames.from_frames(segment) | ||
for segment in self.spec.calculate(bounds, nested) | ||
] | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
spec_info = self.spec.dimension_info() | ||
return DimensionInfo( | ||
axes=spec_info.axes, | ||
shape=spec_info.shape, | ||
snaked=(True,) * len(spec_info.shape), | ||
) | ||
return self.spec.dimension_info() | ||
|
||
|
||
@dataclass(config=StrictConfig) | ||
class Concat(Spec[Axis]): | ||
|
@@ -368,13 +432,6 @@ class Concat(Spec[Axis]): | |
default=True, | ||
) | ||
|
||
def axes(self) -> List: | ||
left_axes, right_axes = self.left.axes(), self.right.axes() | ||
# Assuming the axes are the same, the order does not matter, we inherit the | ||
# order from the left-hand side. See also scanspec.core.concat. | ||
assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}" | ||
return left_axes | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
dim_left = squash_frames( | ||
self.left.calculate(bounds, nested), nested and self.check_path_changes | ||
|
@@ -385,6 +442,19 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | |
dim = dim_left.concat(dim_right, self.gap) | ||
return [dim] | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
left_info = self.left.dimension_info() | ||
right_info = self.right.dimension_info() | ||
assert left_info.axes == right_info.axes | ||
# We will squash each spec into 1 dimension | ||
left_size = reduce(lambda a, b: a * b, left_info.shape) | ||
right_size = reduce(lambda a, b: a * b, right_info.shape) | ||
return DimensionInfo( | ||
axes=left_info.axes, | ||
shape=(left_size + right_size,), | ||
snaked=left_info.snaked, # Non-matching Snake axes cannot be Concat | ||
) | ||
|
||
|
||
@dataclass(config=StrictConfig) | ||
class Squash(Spec[Axis]): | ||
|
@@ -406,14 +476,18 @@ class Squash(Spec[Axis]): | |
default=True, | ||
) | ||
|
||
def axes(self) -> List: | ||
return self.spec.axes() | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
dims = self.spec.calculate(bounds, nested) | ||
dim = squash_frames(dims, nested and self.check_path_changes) | ||
return [dim] | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
spec_info = self.spec.dimension_info() | ||
return DimensionInfo( | ||
axes=(reduce(lambda a, b: a + b, spec_info.axes),), | ||
shape=(reduce(lambda a, b: a * b, spec_info.shape),), | ||
) | ||
|
||
|
||
def _dimensions_from_indexes( | ||
func: Callable[[np.ndarray], Dict[Axis, np.ndarray]], | ||
|
@@ -458,8 +532,8 @@ class Line(Spec[Axis]): | |
stop: float = Field(description="Midpoint of the last point of the line") | ||
num: int = Field(min=1, description="Number of frames to produce") | ||
|
||
def axes(self) -> List: | ||
return [self.axis] | ||
def dimension_info(self) -> DimensionInfo: | ||
return DimensionInfo(axes=((self.axis,),), shape=(self.num,)) | ||
|
||
def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: | ||
if self.num == 1: | ||
|
@@ -475,7 +549,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: | |
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
return _dimensions_from_indexes( | ||
self._line_from_indexes, self.axes(), self.num, bounds | ||
self._line_from_indexes, [self.axis], self.num, bounds | ||
) | ||
|
||
@classmethod | ||
|
@@ -538,17 +612,17 @@ def duration( | |
""" | ||
return cls(DURATION, duration, num) | ||
|
||
def axes(self) -> List: | ||
return [self.axis] | ||
|
||
def _repeats_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: | ||
return {self.axis: np.full(len(indexes), self.value)} | ||
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
return _dimensions_from_indexes( | ||
self._repeats_from_indexes, self.axes(), self.num, bounds | ||
self._repeats_from_indexes, [self.axis], self.num, bounds | ||
) | ||
|
||
def dimension_info(self) -> DimensionInfo: | ||
return DimensionInfo(axes=((self.axis,),), shape=(self.num,)) | ||
|
||
|
||
@dataclass(config=StrictConfig) | ||
class Spiral(Spec[Axis]): | ||
|
@@ -577,9 +651,12 @@ class Spiral(Spec[Axis]): | |
description="How much to rotate the angle of the spiral", default=0.0 | ||
) | ||
|
||
def axes(self) -> List[Axis]: | ||
# TODO: reversed from __init__ args, a good idea? | ||
return [self.y_axis, self.x_axis] | ||
def dimension_info(self) -> DimensionInfo: | ||
return DimensionInfo( | ||
# TODO: reversed from __init__ args, a good idea? | ||
axes=((self.y_axis, self.x_axis),), | ||
shape=(self.num,), | ||
) | ||
|
||
def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: | ||
# simplest spiral equation: r = phi | ||
|
@@ -600,7 +677,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]: | |
|
||
def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]: | ||
return _dimensions_from_indexes( | ||
self._spiral_from_indexes, self.axes(), self.num, bounds | ||
self._spiral_from_indexes, [self.y_axis, self.x_axis], self.num, bounds | ||
callumforrester marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
@classmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.