forked from data-apis/array-api-tests
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshape_helpers.py
177 lines (149 loc) · 4.82 KB
/
shape_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import math
from itertools import product
from typing import Iterator, List, Optional, Sequence, Tuple, Union
from ndindex import iter_indices as _iter_indices
from .typing import AtomicIndex, Index, Scalar, Shape
# Names exported by a star-import of this module.
__all__ = [
    "broadcast_shapes", "normalise_axis", "ndindex", "axis_ndindex",
    "axes_ndindex", "reshape", "fmt_idx",
]
class BroadcastError(ValueError):
    """Raised when two or more shapes cannot be broadcast with each other."""
def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
"""Broadcasts `shape1` and `shape2`"""
N1 = len(shape1)
N2 = len(shape2)
N = max(N1, N2)
shape = [None for _ in range(N)]
i = N - 1
while i >= 0:
n1 = N1 - N + i
if N1 - N + i >= 0:
d1 = shape1[n1]
else:
d1 = 1
n2 = N2 - N + i
if N2 - N + i >= 0:
d2 = shape2[n2]
else:
d2 = 1
if d1 == 1:
shape[i] = d2
elif d2 == 1:
shape[i] = d1
elif d1 == d2:
shape[i] = d1
else:
raise BroadcastError()
i = i - 1
return tuple(shape)
def broadcast_shapes(*shapes: Shape):
    """Broadcast any number of shapes together, folding left to right.

    A single shape is returned unchanged (the same object). With two or more
    shapes, each is broadcast into the running result via `_broadcast_shapes`.

    Raises:
        ValueError: if no shapes are given.
        BroadcastError: propagated from `_broadcast_shapes` when the shapes
            are incompatible.
    """
    if not shapes:
        raise ValueError("shapes=[] must be non-empty")
    head, *tail = shapes
    if not tail:
        return head
    result = head
    for shape in tail:
        result = _broadcast_shapes(result, shape)
    return result
def normalise_axis(
    axis: Optional[Union[int, Sequence[int]]], ndim: int
) -> Tuple[int, ...]:
    """Return `axis` as a tuple of non-negative axes for an `ndim`-dim array.

    ``None`` means every axis, ``tuple(range(ndim))``. A single int or any
    sequence of ints is converted to a tuple, and each negative entry has
    `ndim` added to it. No bounds checking is performed.
    """
    if axis is None:
        return tuple(range(ndim))
    if isinstance(axis, tuple):
        raw_axes = axis
    elif isinstance(axis, Sequence):
        raw_axes = tuple(axis)
    else:
        raw_axes = (axis,)
    return tuple(ndim + a if a < 0 else a for a in raw_axes)
def ndindex(shape: Shape) -> Iterator[Index]:
    """Lazily iterate over every index of `shape`.

    Delegates to `iter_indices` with this single shape and yields element 0
    of each tuple it produces.
    """
    for indices in iter_indices(shape):
        yield indices[0]
def iter_indices(
    *shapes: Shape, skip_axes: Tuple[int, ...] = ()
) -> Iterator[Tuple[Index, ...]]:
    """Wrapper for ndindex.iter_indices() that yields raw Python indices.

    Yields nothing at all if any of `shapes` contains a 0-sided dimension.
    Otherwise each tuple produced by ``ndindex.iter_indices(*shapes,
    skip_axes=skip_axes)`` is re-emitted with every element replaced by its
    ``.raw`` attribute.
    """
    if any(0 in shape for shape in shapes):
        return
    for nd_indices in _iter_indices(*shapes, skip_axes=skip_axes):
        yield tuple(nd_index.raw for nd_index in nd_indices)  # type: ignore
def axis_ndindex(
    shape: Shape, axis: int
) -> Iterator[Tuple[Union[int, slice], ...]]:
    """Generate indices that index all elements in dimensions from `axis` on.

    Each yielded tuple has one concrete integer per dimension before `axis`
    (iterating in row-major order via `itertools.product`). It then has a
    full slice ``:`` for `axis` and for every dimension after it. If `axis`
    is at least ``len(shape)``, every dimension gets an integer.

    Args:
        shape: Shape of the array being indexed.
        axis: First dimension to take whole; must be non-negative.
    """
    assert axis >= 0  # sanity check
    leading = [range(side) for side in shape[:axis]]
    # A negative repeat count (axis > ndim) yields an empty list, i.e. no
    # trailing slices -- matching an empty range(axis, len(shape)).
    trailing = [[slice(None, None)]] * (len(shape) - axis)
    yield from product(*leading, *trailing)
def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
    """Generate indices that index all elements except in `axes` dimensions.

    The outer loop walks every combination of positions in the dimensions
    NOT listed in `axes`. For each combination, one list is yielded holding
    the full index tuples obtained by also walking every position in the
    `axes` dimensions. Each yielded item is a list of index tuples.
    """
    # Per-dimension iterables; `None` is a placeholder for dimensions owned
    # by the other loop and is overwritten when the two are merged.
    outer_ranges = []
    inner_ranges = []
    for dim, side in enumerate(shape):
        in_axes = dim in axes
        outer_ranges.append([None] if in_axes else range(side))
        inner_ranges.append(range(side) if in_axes else [None])
    for outer in product(*outer_ranges):
        group = []
        for inner in product(*inner_ranges):
            merged = tuple(
                pos if dim in axes else outer[dim]
                for dim, pos in enumerate(inner)
            )
            group.append(merged)
        yield group
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]:
    """Reshape a flat sequence into nested lists following `shape`.

    A 0-dimensional shape returns the sole element. A 1-dimensional shape
    returns `flat_seq` itself (not a copy). Higher ranks split `flat_seq`
    into consecutive chunks of ``prod(shape[1:])`` elements and recurse on
    each chunk; any trailing elements that don't fill a whole chunk are
    dropped.

    Raises:
        ValueError: if any side of `shape` is 0 (nested lists can't
            represent an empty dimension with a non-empty tail).
    """
    if 0 in shape:
        raise ValueError(
            f"{shape=} contains 0-sided dimensions, "
            f"but that's not representable in lists"
        )
    if not shape:
        assert len(flat_seq) == 1  # sanity check
        return flat_seq[0]
    if len(shape) == 1:
        return flat_seq
    inner_shape = shape[1:]
    stride = math.prod(inner_shape)
    starts = range(0, (len(flat_seq) // stride) * stride, stride)
    return [reshape(flat_seq[start : start + stride], inner_shape) for start in starts]
def fmt_i(i: AtomicIndex) -> str:
    """Format a single atomic index as it would appear inside ``[...]``.

    - ``int`` -> its decimal string, e.g. ``"3"``.
    - ``slice`` -> ``"start:stop"`` plus ``":step"`` when a step is set, with
      ``None`` parts left blank (``slice(None)`` -> ``":"``).
    - ``None`` -> ``"None"``.
    - ``Ellipsis`` -> ``"..."``.
    - Anything else -> ``str(i)``. Previously such values were silently
      rendered as ``"..."``, which mislabelled them.
    """
    if isinstance(i, int):
        return str(i)
    if isinstance(i, slice):
        start = "" if i.start is None else str(i.start)
        stop = "" if i.stop is None else str(i.stop)
        res = f"{start}:{stop}"
        if i.step is not None:
            res += f":{i.step}"
        return res
    if i is None:
        return "None"
    if i is Ellipsis:
        return "..."
    # Outside the AtomicIndex contract (e.g. an integer-like scalar that is
    # not a Python int): show the value rather than pretending it is "...".
    return str(i)
def fmt_idx(sym: str, idx: Index) -> str:
    """Render `sym` indexed by `idx`, e.g. ``x[0, 1:3]``.

    An empty-tuple index returns `sym` unchanged. A non-tuple index is
    treated as a 1-tuple. Each component is formatted with `fmt_i` and the
    results are joined with ``", "``.
    """
    if idx == ():
        return sym
    parts = idx if isinstance(idx, tuple) else (idx,)
    inner = ", ".join(fmt_i(part) for part in parts)
    return f"{sym}[{inner}]"