7
7
import pytensor
8
8
from pytensor .graph import FunctionGraph , Variable
9
9
from pytensor .npy_2_compat import normalize_axis_tuple
10
+ from pytensor .tensor import Any , Constant
10
11
from pytensor .utils import hash_from_code
11
12
12
13
@@ -203,8 +204,8 @@ def _parse_gufunc_signature(
203
204
204
205
205
206
def _gufunc_to_out_shape (
206
- signature : str , shapes : list [tuple [int , ...]]
207
- ) -> list [tuple [int , ...]]:
207
+ signature : str , shapes : list [tuple [Any , ...]]
208
+ ) -> list [tuple [Any , ...]]:
208
209
"""
209
210
Compute the shape of the output of an Op given its gufunc signature and the
210
211
shapes of its inputs.
@@ -215,24 +216,47 @@ def _gufunc_to_out_shape(
215
216
The gufunc signature of the Op.
216
217
eg: "(m,n),(n,p)->(m,p)".
217
218
218
- shapes : list of tuple of int
219
+ shapes : list of tuple of Any
219
220
The list of shapes of the inputs.
220
221
221
222
Returns
222
223
-------
223
- out_shape : list of tuple of int
224
+ out_shape : list of tuple of Any
224
225
The list of shapes of the outputs.
226
+
227
+ Raises
228
+ ------
229
+ ValueError
230
+ If the signature is invalid for the shapes of the inputs.
225
231
"""
226
- parsed = _parse_gufunc_signature (signature )
227
- out_shape = []
228
- dic = dict ()
229
- for i in range (len (parsed [0 ])):
230
- for j in range (len (parsed [0 ][i ])):
231
- dic [parsed [0 ][i ][j ]] = shapes [i ][j ]
232
- for i in range (len (parsed [1 ])):
233
- temp_list = [dic [x ] for x in parsed [1 ][i ]]
234
- out_shape .append (tuple (temp_list ))
235
- return out_shape
232
+ input_sig , output_sig = _parse_gufunc_signature (signature )
233
+ dim_to_size : dict [str , Any ] = {}
234
+ for input_shape , sig in zip (shapes , input_sig , strict = True ):
235
+ for size , dim_name in zip (input_shape , sig , strict = True ):
236
+ prev_size = dim_to_size .get (dim_name )
237
+ if prev_size is None :
238
+ dim_to_size [dim_name ] = size
239
+ # Prefer constants
240
+ elif not isinstance (prev_size , Constant ):
241
+ dim_to_size [dim_name ] = size
242
+ elif prev_size .data != size :
243
+ raise ValueError (
244
+ f"Invalid signature { signature } for shapes { shapes } . "
245
+ f"Dimension { dim_name } is not consistent across inputs."
246
+ )
247
+ out_shapes = []
248
+ for output_shape in output_sig :
249
+ temp_list = []
250
+ for dim in output_shape :
251
+ if dim not in dim_to_size :
252
+ raise ValueError (
253
+ f"Invalid signature { signature } for shapes { shapes } . "
254
+ f"Dimension { dim } not in input dimensions."
255
+ )
256
+ else :
257
+ temp_list .append (dim_to_size [dim ])
258
+ out_shapes .append ((* temp_list ,))
259
+ return out_shapes
236
260
237
261
238
262
def safe_signature (
0 commit comments