Skip to content

Commit 870b900

Browse files
committed
Refactor _gufunc_to_out_shape for giving priority to Constant dimensions along with error handling
1 parent 7931668 commit 870b900

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

pytensor/tensor/utils.py

+38-14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytensor
88
from pytensor.graph import FunctionGraph, Variable
99
from pytensor.npy_2_compat import normalize_axis_tuple
10+
from pytensor.tensor import Any, Constant
1011
from pytensor.utils import hash_from_code
1112

1213

@@ -203,8 +204,8 @@ def _parse_gufunc_signature(
203204

204205

205206
def _gufunc_to_out_shape(
206-
signature: str, shapes: list[tuple[int, ...]]
207-
) -> list[tuple[int, ...]]:
207+
signature: str, shapes: list[tuple[Any, ...]]
208+
) -> list[tuple[Any, ...]]:
208209
"""
209210
Compute the shape of the output of an Op given its gufunc signature and the
210211
shapes of its inputs.
@@ -215,24 +216,47 @@ def _gufunc_to_out_shape(
215216
The gufunc signature of the Op.
216217
eg: "(m,n),(n,p)->(m,p)".
217218
218-
shapes : list of tuple of int
219+
shapes : list of tuple of Any
219220
The list of shapes of the inputs.
220221
221222
Returns
222223
-------
223-
out_shape : list of tuple of int
224+
out_shape : list of tuple of Any
224225
The list of shapes of the outputs.
226+
227+
Raises
228+
------
229+
ValueError
230+
If the signature is invalid for the shapes of the inputs.
225231
"""
226-
parsed = _parse_gufunc_signature(signature)
227-
out_shape = []
228-
dic = dict()
229-
for i in range(len(parsed[0])):
230-
for j in range(len(parsed[0][i])):
231-
dic[parsed[0][i][j]] = shapes[i][j]
232-
for i in range(len(parsed[1])):
233-
temp_list = [dic[x] for x in parsed[1][i]]
234-
out_shape.append(tuple(temp_list))
235-
return out_shape
232+
input_sig, output_sig = _parse_gufunc_signature(signature)
233+
dim_to_size: dict[str, Any] = {}
234+
for input_shape, sig in zip(shapes, input_sig, strict=True):
235+
for size, dim_name in zip(input_shape, sig, strict=True):
236+
prev_size = dim_to_size.get(dim_name)
237+
if prev_size is None:
238+
dim_to_size[dim_name] = size
239+
# Prefer constants
240+
elif not isinstance(prev_size, Constant):
241+
dim_to_size[dim_name] = size
242+
elif prev_size.data != size:
243+
raise ValueError(
244+
f"Invalid signature {signature} for shapes {shapes}. "
245+
f"Dimension {dim_name} is not consistent across inputs."
246+
)
247+
out_shapes = []
248+
for output_shape in output_sig:
249+
temp_list = []
250+
for dim in output_shape:
251+
if dim not in dim_to_size:
252+
raise ValueError(
253+
f"Invalid signature {signature} for shapes {shapes}. "
254+
f"Dimension {dim} not in input dimensions."
255+
)
256+
else:
257+
temp_list.append(dim_to_size[dim])
258+
out_shapes.append((*temp_list,))
259+
return out_shapes
236260

237261

238262
def safe_signature(

0 commit comments

Comments
 (0)