Add dimensionality information to Specs #89

Open
wants to merge 2 commits into
base: main
25 changes: 25 additions & 0 deletions src/scanspec/core.py
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -515,6 +516,30 @@ def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[
return squashed


class DimensionInfo(Generic[Axis]):
def __init__(
self,
axes: Tuple[Tuple[Axis, ...]],
shape: Tuple[int, ...],
snaked: Tuple[bool, ...] = None,
Contributor

Suggested change
snaked: Tuple[bool, ...] = None,
snaked: Optional[Tuple[bool, ...]] = None,

):
self._axes = axes
self._shape = shape
self._snaked = snaked or (False,) * len(shape)
Contributor

Unsure if this would help, but is it easier to make snaked an enum of yes, no and unknown?

Contributor Author

Snaked is False until an inner Dimension when it's True. Maybe it should be the outermost Dimension that is Snaked instead, or a dimension index?
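
To make the enum suggestion above concrete, here is a minimal sketch. `SnakeState` and `default_snaked` are hypothetical names that do not exist in scanspec; the sketch only shows how an explicit "unknown" value could replace the `None` default, and it assumes one entry per dimension.

```python
from enum import Enum
from typing import Optional, Tuple


class SnakeState(Enum):
    """Hypothetical tri-state flag; not part of scanspec."""

    YES = "yes"
    NO = "no"
    UNKNOWN = "unknown"


def default_snaked(
    shape: Tuple[int, ...], snaked: Optional[Tuple[SnakeState, ...]] = None
) -> Tuple[SnakeState, ...]:
    # Fall back to UNKNOWN for every dimension instead of silently assuming False
    if snaked is None:
        return (SnakeState.UNKNOWN,) * len(shape)
    # Assumption: callers supply exactly one state per dimension
    if len(snaked) != len(shape):
        raise ValueError("snaked must have one entry per dimension")
    return snaked
```

Compared with `snaked or (False,) * len(shape)`, this keeps "not specified" distinguishable from "explicitly not snaked", at the cost of a wider type for downstream consumers.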


@property
def axes(self) -> Tuple[Tuple[Axis, ...]]:
return self._axes

@property
def shape(self) -> Tuple[int, ...]:
return self._shape

@property
def snaked(self) -> Tuple[bool, ...]:
return self._snaked


Comment on lines +519 to +542
Contributor

This achieves the same thing and is more in keeping with the rest of the codebase.

Suggested change
class DimensionInfo(Generic[Axis]):
def __init__(
self,
axes: Tuple[Tuple[Axis, ...]],
shape: Tuple[int, ...],
snaked: Tuple[bool, ...] = None,
):
self._axes = axes
self._shape = shape
self._snaked = snaked or (False,) * len(shape)
@property
def axes(self) -> Tuple[Tuple[Axis, ...]]:
return self._axes
@property
def shape(self) -> Tuple[int, ...]:
return self._shape
@property
def snaked(self) -> Tuple[bool, ...]:
return self._snaked
@dataclass(frozen=True)
class DimensionInfo:
axes: Tuple[Tuple[Axis, ...]]
shape: Tuple[int, ...]
snaked: Tuple[bool, ...]

Contributor

I would also prefer Dimensions to DimensionInfo, because that follows a similar naming convention to the other classes.

Contributor Author

It's not in keeping with the other classes in this file, which is why I didn't make it a dataclass. That may be a holdover from worse dataclass Generic handling? I'll see how it behaves and maybe make a PR to move Frames etc. over to dataclasses: the specs are dataclasses but the Frames etc. are not.

RE: Dimensions vs DimensionInfo: Dimension has meaning from SPG that is already present in some of the docs. This object is not a Dimension or a collection of Dimensions; it describes some Dimensions.
Each component of axes/shape/snaked refers to the behaviour of a single Dimension. DimensionsInfo, maybe?

Contributor

Could also suggest Dimensionality

Contributor Author

Dimensionality I like

Contributor

Good point, none of the other classes in core are dataclasses, but they do all contain a lot of logic, so I think that's correct. This almost purely holds data.
From discussion with Tom, he's leaning towards suggesting that there be no default for snaked and that it be dealt with downstream. That would then make this a pure data class.
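
Where this thread appears to be heading can be sketched as a frozen dataclass with no default for `snaked`. This is an illustration under assumptions, not the merged code: the name `Dimensionality` is the one floated above, `Axis` is a local stand-in for `scanspec.core.Axis`, and the `axes` annotation is written as `Tuple[Tuple[Axis, ...], ...]` on the assumption that a variable-length tuple of tuples is intended.

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Generic, Tuple, TypeVar

Axis = TypeVar("Axis")  # stand-in for scanspec.core.Axis


@dataclass(frozen=True)
class Dimensionality(Generic[Axis]):
    """Sketch only: describes dimensions, holds no logic and no defaults."""

    axes: Tuple[Tuple[Axis, ...], ...]
    shape: Tuple[int, ...]
    snaked: Tuple[bool, ...]


# Every field must be supplied explicitly, including snaked
info = Dimensionality(axes=(("y",), ("x",)), shape=(3, 4), snaked=(False, True))
```

Because the class is frozen, assigning to `info.shape` raises `FrozenInstanceError`, which gives the read-only behaviour that the hand-written properties provided.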

class Path(Generic[Axis]):
"""A consumable route through a stack of Frames, representing a scan path.

Expand Down
153 changes: 115 additions & 38 deletions src/scanspec/specs.py
@@ -1,14 +1,17 @@
from __future__ import annotations

from dataclasses import asdict
from functools import reduce
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, Tuple, Type
from warnings import warn

import numpy as np
from pydantic import Field, parse_obj_as
from pydantic.dataclasses import dataclass

from .core import (
Axis,
DimensionInfo,
Frames,
Midpoints,
Path,
Expand Down Expand Up @@ -60,7 +63,14 @@ def axes(self) -> List[Axis]:

Ordered from slowest moving to fastest moving.
"""
raise NotImplementedError(self)
warn(
"axes() is deprecated, call dimension_info()",
Contributor

Not sure this should be deprecated. I think there are still valid reasons to just say "I want a list of all axes that may be involved in this scan" without caring about/having to unpack the dimensions.

Contributor Author

AIUI Tom requested deprecating axes(), shape(), will have to check when he's back.
The handling to get from the info is there, so it's easy to leave them both in and not worry about it again.

Contributor

Chatted with Tom, happy to deprecate for now and add a .flat_axes() to Dimensionality if a need arises
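
For reference, the flattening that a `.flat_axes()` would need is small. The sketch below is a free function under a hypothetical name, since no such helper exists in this PR; it assumes the per-dimension axes are ordered slowest moving first, as the `axes()` docstring describes.

```python
from itertools import chain
from typing import List, Tuple, TypeVar

Axis = TypeVar("Axis")  # stand-in for scanspec.core.Axis


def flat_axes(axes: Tuple[Tuple[Axis, ...], ...]) -> List[Axis]:
    """Hypothetical helper: flatten per-dimension axes, keeping their order."""
    # chain.from_iterable concatenates the inner tuples without nested loops
    return list(chain.from_iterable(axes))
```

For example, `flat_axes((("z",), ("y", "x")))` gives `["z", "y", "x"]`, the same list the deprecated `axes()` builds with `reduce`.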

DeprecationWarning,
stacklevel=2,
)
return reduce(
# Pass the initial value positionally: reduce() has historically rejected keywords
lambda a, b: a + list(b), self.dimension_info().axes, list()
)

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
"""Produce a stack of nested `Frames` that form the scan.
Expand All @@ -69,6 +79,16 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
"""
raise NotImplementedError(self)

def dimension_info(self) -> DimensionInfo:
"""Returns the list of axes in each dimension of the scan,
paired with the information on how large each dimension of the scan is,
and whether the dimension is snaked in the dimension outside it.

Deprecates shape() as does not need to call calculate()
Deprecates axes() as has the per-dimension information
"""
raise NotImplementedError(self)

def frames(self) -> Frames[Axis]:
"""Expand all the scan `Frames` and return them."""
return Path(self.calculate()).consume()
Expand All @@ -79,7 +99,12 @@ def midpoints(self) -> Midpoints[Axis]:

def shape(self) -> Tuple[int, ...]:
"""Return the final, simplified shape of the scan."""
return tuple(len(dim) for dim in self.calculate())
warn(
"shape() is deprecated, call dimension_info()",
DeprecationWarning,
stacklevel=2,
)
return self.dimension_info().shape

def __rmul__(self, other) -> Product[Axis]:
return if_instance_do(other, int, lambda o: Product(Repeat(o), self))
Expand Down Expand Up @@ -127,14 +152,20 @@ class Product(Spec[Axis]):
outer: Spec[Axis] = Field(description="Will be executed once")
inner: Spec[Axis] = Field(description="Will be executed len(outer) times")

def axes(self) -> List:
return self.outer.axes() + self.inner.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames_outer = self.outer.calculate(bounds=False, nested=nested)
frames_inner = self.inner.calculate(bounds, nested=True)
return frames_outer + frames_inner

def dimension_info(self) -> DimensionInfo:
outer_info = self.outer.dimension_info()
inner_info = self.inner.dimension_info()
return DimensionInfo(
axes=outer_info.axes + inner_info.axes,
shape=outer_info.shape + inner_info.shape,
snaked=outer_info.snaked + inner_info.snaked,
)


@dataclass(config=StrictConfig)
class Repeat(Spec[Axis]):
Expand Down Expand Up @@ -166,12 +197,12 @@ class Repeat(Spec[Axis]):
default=True,
)

def axes(self) -> List:
return []

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return [Frames({}, gap=np.full(self.num, self.gap))]

def dimension_info(self) -> DimensionInfo:
return DimensionInfo(axes=((DURATION,),), shape=(self.num,))


@dataclass(config=StrictConfig)
class Zip(Spec[Axis]):
Expand Down Expand Up @@ -203,9 +234,6 @@ class Zip(Spec[Axis]):
description="The right-hand Spec to Zip, will appear later in axes"
)

def axes(self) -> List:
return self.left.axes() + self.right.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames_left = self.left.calculate(bounds, nested)
frames_right = self.right.calculate(bounds, nested)
Expand Down Expand Up @@ -243,6 +271,24 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames.append(combined)
return frames

def dimension_info(self) -> DimensionInfo:
left_info = self.left.dimension_info()
right_info = self.right.dimension_info()
left = left_info.axes
padded_right = ((None,),) * (
len(left_info.axes) - len(right_info.axes)
) + right_info.axes
axes = tuple(
left[i] if padded_right[i] == (None,) else left[i] + padded_right[i]
for i in range(len(left_info.axes))
)

return DimensionInfo(
axes=axes,
shape=left_info.shape,
snaked=left_info.snaked, # Non-matching Snake axes cannot be Zipped
)
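
The axis merging in `Zip.dimension_info` can be read as "align the right-hand dimensions with the innermost left-hand dimensions, then concatenate per dimension". The sketch below is a simplified stand-alone version using plain strings for axes; `zip_axes` is a hypothetical name, it pads with empty tuples rather than the `(None,)` sentinel used in the diff, and the length check is an added assumption.

```python
from typing import Tuple

Axes = Tuple[Tuple[str, ...], ...]


def zip_axes(left: Axes, right: Axes) -> Axes:
    """Merge right-hand axes into the innermost dimensions of left."""
    # Assumption: the right-hand side never has more dimensions than the left
    if len(right) > len(left):
        raise ValueError("right may not have more dimensions than left")
    # Pad with empty tuples so right lines up with the innermost dimensions
    padded_right = ((),) * (len(left) - len(right)) + right
    return tuple(l + r for l, r in zip(left, padded_right))
```

With `left = (("z",), ("y",))` and `right = (("x",),)`, the outer dimension is left untouched and the inner one becomes `("y", "x")`.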


@dataclass(config=StrictConfig)
class Mask(Spec[Axis]):
Expand Down Expand Up @@ -271,9 +317,6 @@ class Mask(Spec[Axis]):
default=True,
)

def axes(self) -> List:
return self.spec.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
frames = self.spec.calculate(bounds, nested)
for axis_set in self.region.axis_sets():
Expand All @@ -295,6 +338,21 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
masked_frames.append(f.extract(indices))
return masked_frames

def dimension_info(self) -> DimensionInfo:
"""
As Mask applies a Region to a Spec, which may alter the Spec, but this is not
knowable without calculating the entire Spec, we have to calculate the Spec
here.
Currently we throw away the results of this calculation, but in future we may
want to cache the result, or else modify the behaviour of this method generally
to match.
"""
frames = self.calculate(bounds=False, nested=False)
shape = tuple(len(x.midpoints) for x in frames)
axes = tuple(tuple(x.axes()) for x in frames)
snaked = tuple(isinstance(x, SnakedFrames) for x in frames)
return DimensionInfo(axes=axes, shape=shape, snaked=snaked)

# *+ bind more tightly than &|^ so without these overrides we
# would need to add brackets to all combinations of Regions
def __or__(self, other: Region[Axis]) -> Mask[Axis]:
Expand Down Expand Up @@ -329,15 +387,21 @@ class Snake(Spec[Axis]):
description="The Spec to run in reverse every other iteration"
)

def axes(self) -> List:
return self.spec.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return [
SnakedFrames.from_frames(segment)
for segment in self.spec.calculate(bounds, nested)
]

def dimension_info(self) -> DimensionInfo:
spec_info = self.spec.dimension_info()
return DimensionInfo(
axes=spec_info.axes,
shape=spec_info.shape,
snaked=(True,) * len(spec_info.shape),
)


@dataclass(config=StrictConfig)
class Concat(Spec[Axis]):
Expand Down Expand Up @@ -368,13 +432,6 @@ class Concat(Spec[Axis]):
default=True,
)

def axes(self) -> List:
left_axes, right_axes = self.left.axes(), self.right.axes()
# Assuming the axes are the same, the order does not matter, we inherit the
# order from the left-hand side. See also scanspec.core.concat.
assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}"
return left_axes

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
dim_left = squash_frames(
self.left.calculate(bounds, nested), nested and self.check_path_changes
Expand All @@ -385,6 +442,19 @@ def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
dim = dim_left.concat(dim_right, self.gap)
return [dim]

def dimension_info(self) -> DimensionInfo:
left_info = self.left.dimension_info()
right_info = self.right.dimension_info()
assert left_info.axes == right_info.axes
# We will squash each spec into 1 dimension
left_size = reduce(lambda a, b: a * b, left_info.shape)
right_size = reduce(lambda a, b: a * b, right_info.shape)
return DimensionInfo(
axes=left_info.axes,
shape=(left_size + right_size,),
snaked=left_info.snaked, # Non-matching Snake axes cannot be Concat
)


@dataclass(config=StrictConfig)
class Squash(Spec[Axis]):
Expand All @@ -406,14 +476,18 @@ class Squash(Spec[Axis]):
default=True,
)

def axes(self) -> List:
return self.spec.axes()

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
dims = self.spec.calculate(bounds, nested)
dim = squash_frames(dims, nested and self.check_path_changes)
return [dim]

def dimension_info(self) -> DimensionInfo:
spec_info = self.spec.dimension_info()
return DimensionInfo(
axes=(reduce(lambda a, b: a + b, spec_info.axes),),
shape=(reduce(lambda a, b: a * b, spec_info.shape),),
)
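
The shape arithmetic shared by `Concat.dimension_info` and `Squash.dimension_info` is easy to check by hand. This sketch restates it with `math.prod` in place of `reduce`; the function names are illustrative only and are not part of scanspec.

```python
from math import prod
from typing import Tuple


def squashed_shape(shape: Tuple[int, ...]) -> Tuple[int]:
    """Squash collapses all dimensions into one holding every frame."""
    return (prod(shape),)


def concat_shape(left: Tuple[int, ...], right: Tuple[int, ...]) -> Tuple[int]:
    """Concat squashes each side first, then appends right after left."""
    return (prod(left) + prod(right),)
```

A 3 by 4 grid squashes to a single dimension of 12 frames, and concatenating it with a 5 point line gives one dimension of 17 frames.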


def _dimensions_from_indexes(
func: Callable[[np.ndarray], Dict[Axis, np.ndarray]],
Expand Down Expand Up @@ -458,8 +532,8 @@ class Line(Spec[Axis]):
stop: float = Field(description="Midpoint of the last point of the line")
num: int = Field(min=1, description="Number of frames to produce")

def axes(self) -> List:
return [self.axis]
def dimension_info(self) -> DimensionInfo:
return DimensionInfo(axes=((self.axis,),), shape=(self.num,))

def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
if self.num == 1:
Expand All @@ -475,7 +549,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return _dimensions_from_indexes(
self._line_from_indexes, self.axes(), self.num, bounds
self._line_from_indexes, [self.axis], self.num, bounds
)

@classmethod
Expand Down Expand Up @@ -538,17 +612,17 @@ def duration(
"""
return cls(DURATION, duration, num)

def axes(self) -> List:
return [self.axis]

def _repeats_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
return {self.axis: np.full(len(indexes), self.value)}

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return _dimensions_from_indexes(
self._repeats_from_indexes, self.axes(), self.num, bounds
self._repeats_from_indexes, [self.axis], self.num, bounds
)

def dimension_info(self) -> DimensionInfo:
return DimensionInfo(axes=((self.axis,),), shape=(self.num,))


@dataclass(config=StrictConfig)
class Spiral(Spec[Axis]):
Expand Down Expand Up @@ -577,9 +651,12 @@ class Spiral(Spec[Axis]):
description="How much to rotate the angle of the spiral", default=0.0
)

def axes(self) -> List[Axis]:
# TODO: reversed from __init__ args, a good idea?
return [self.y_axis, self.x_axis]
def dimension_info(self) -> DimensionInfo:
return DimensionInfo(
# TODO: reversed from __init__ args, a good idea?
axes=((self.y_axis, self.x_axis),),
shape=(self.num,),
)

def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:
# simplest spiral equation: r = phi
Expand All @@ -600,7 +677,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> Dict[Axis, np.ndarray]:

def calculate(self, bounds=True, nested=False) -> List[Frames[Axis]]:
return _dimensions_from_indexes(
self._spiral_from_indexes, self.axes(), self.num, bounds
self._spiral_from_indexes, [self.y_axis, self.x_axis], self.num, bounds
)

@classmethod