
Commit 4b41e09

Armavica authored and ricardoV94 committed
Add exceptions for hot loops
1 parent 54fba94 commit 4b41e09

19 files changed (+77 -48 lines)

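The change is mechanical: in loops that run once per function call or once per output, `zip(..., strict=True)` becomes `zip(..., strict=False)`, and a comment records why the strict check is safe to drop (the loop is performance-critical, or the lengths were already checked). In CPython the strict check only runs when the first iterator is exhausted, so its relative cost is largest for the very short zips these Ops create on every call. A minimal illustrative sketch (not part of the commit; absolute numbers depend on the interpreter, only the relative gap matters):

# Minimal sketch comparing the two forms on a short iterable, as created once
# per Op call in the hot paths below. Timings are illustrative only.
import timeit

outputs = list(range(4))
variables = list(range(4))

loose = timeit.timeit(
    "for o, v in zip(outputs, variables, strict=False): pass",
    globals=globals(),
    number=1_000_000,
)
strict = timeit.timeit(
    "for o, v in zip(outputs, variables, strict=True): pass",
    globals=globals(),
    number=1_000_000,
)
print(f"strict=False: {loose:.3f}s   strict=True: {strict:.3f}s")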

pytensor/compile/builders.py (+2 -1)

@@ -863,5 +863,6 @@ def clone(self):
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
         assert len(variables) == len(outputs)
-        for output, variable in zip(outputs, variables, strict=True):
+        # strict=False because asserted above
+        for output, variable in zip(outputs, variables, strict=False):
             output[0] = variable
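Here (and in the PyTorch `SpecifyShape` dispatcher below) the comment is "asserted above" rather than "hot loop": the preceding `assert` already guarantees equal lengths, so the strict check could never fire. A hypothetical standalone illustration of that pattern (`copy_outputs` is not a PyTensor function):

# Hypothetical helper illustrating the "asserted above" pattern; any length
# mismatch is caught by the assert, so zip's strict check is redundant.
def copy_outputs(variables, outputs):
    assert len(variables) == len(outputs)
    for output, variable in zip(outputs, variables, strict=False):
        output[0] = variable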

pytensor/compile/function/types.py (+8 -4)

@@ -1002,8 +1002,9 @@ def __call__(self, *args, **kwargs):
             # if we are allowing garbage collection, remove the
             # output reference from the internal storage cells
             if getattr(self.vm, "allow_gc", False):
+                # strict=False because we are in a hot loop
                 for o_container, o_variable in zip(
-                    self.output_storage, self.maker.fgraph.outputs, strict=True
+                    self.output_storage, self.maker.fgraph.outputs, strict=False
                 ):
                     if o_variable.owner is not None:
                         # this node is the variable of computation
@@ -1012,8 +1013,9 @@ def __call__(self, *args, **kwargs):

             if getattr(self.vm, "need_update_inputs", True):
                 # Update the inputs that have an update function
+                # strict=False because we are in a hot loop
                 for input, storage in reversed(
-                    list(zip(self.maker.expanded_inputs, input_storage, strict=True))
+                    list(zip(self.maker.expanded_inputs, input_storage, strict=False))
                 ):
                     if input.update is not None:
                         storage.data = outputs.pop()
@@ -1044,7 +1046,8 @@ def __call__(self, *args, **kwargs):
         assert len(self.output_keys) == len(outputs)

         if output_subset is None:
-            return dict(zip(self.output_keys, outputs, strict=True))
+            # strict=False because we are in a hot loop
+            return dict(zip(self.output_keys, outputs, strict=False))
         else:
             return {
                 self.output_keys[index]: outputs[index]
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
     ins = list(f.input_storage)
     input_storage = []

+    # strict=False because we are in a hot loop
     for (input, indices, inputs), (required, refeed, default) in zip(
-        f.indices, f.defaults, strict=True
+        f.indices, f.defaults, strict=False
     ):
         input_storage.append(ins[0])
         del ins[0]

pytensor/ifelse.py (+4 -2)

@@ -305,7 +305,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                for out, t in zip(outputs, input_true_branch, strict=True):
+                # strict=False because we are in a hot loop
+                for out, t in zip(outputs, input_true_branch, strict=False):
                     compute_map[out][0] = 1
                     val = storage_map[t][0]
                     if self.as_view:
@@ -325,7 +326,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                for out, f in zip(outputs, inputs_false_branch, strict=True):
+                # strict=False because we are in a hot loop
+                for out, f in zip(outputs, inputs_false_branch, strict=False):
                     compute_map[out][0] = 1
                     # can't view both outputs unless destroyhandler
                     # improves

pytensor/link/basic.py (+6 -3)

@@ -539,12 +539,14 @@ def make_thunk(self, **kwargs):

        def f():
            for inputs in input_lists[1:]:
-               for input1, input2 in zip(inputs0, inputs, strict=True):
+               # strict=False because we are in a hot loop
+               for input1, input2 in zip(inputs0, inputs, strict=False):
                    input2.storage[0] = copy(input1.storage[0])
            for x in to_reset:
                x[0] = None
            pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-           for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=True)):
+           # strict=False because we are in a hot loop
+           for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
                try:
                    wrapper(self.fgraph, i, node, *thunks)
                except Exception:
@@ -666,8 +668,9 @@ def thunk(
        ):
            outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])

+           # strict=False because we are in a hot loop
            for o_var, o_storage, o_val in zip(
-               fgraph.outputs, thunk_outputs, outputs, strict=True
+               fgraph.outputs, thunk_outputs, outputs, strict=False
            ):
                compute_map[o_var][0] = True
                o_storage[0] = self.output_filter(o_var, o_val)

pytensor/link/c/basic.py (+6 -5)

@@ -1993,25 +1993,26 @@ def make_thunk(self, **kwargs):
        )

        def f():
-           for input1, input2 in zip(i1, i2, strict=True):
+           # strict=False because we are in a hot loop
+           for input1, input2 in zip(i1, i2, strict=False):
                # Set the inputs to be the same in both branches.
                # The copy is necessary in order for inplace ops not to
                # interfere.
                input2.storage[0] = copy(input1.storage[0])
            for thunk1, thunk2, node1, node2 in zip(
-               thunks1, thunks2, order1, order2, strict=True
+               thunks1, thunks2, order1, order2, strict=False
            ):
-               for output, storage in zip(node1.outputs, thunk1.outputs, strict=True):
+               for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
                    if output in no_recycling:
                        storage[0] = None
-               for output, storage in zip(node2.outputs, thunk2.outputs, strict=True):
+               for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
                    if output in no_recycling:
                        storage[0] = None
                try:
                    thunk1()
                    thunk2()
                    for output1, output2 in zip(
-                       thunk1.outputs, thunk2.outputs, strict=True
+                       thunk1.outputs, thunk2.outputs, strict=False
                    ):
                        self.checker(output1, output2)
                except Exception:

pytensor/link/numba/dispatch/basic.py (+2 -1)

@@ -401,9 +401,10 @@ def py_perform_return(inputs):
    else:

        def py_perform_return(inputs):
+           # strict=False because we are in a hot loop
            return tuple(
                out_type.filter(out[0])
-               for out_type, out in zip(output_types, py_perform(inputs), strict=True)
+               for out_type, out in zip(output_types, py_perform(inputs), strict=False)
            )

    @numba_njit

pytensor/link/pytorch/dispatch/shape.py (+2 -1)

@@ -34,7 +34,8 @@ def shape_i(x):
 def pytorch_funcify_SpecifyShape(op, node, **kwargs):
     def specifyshape(x, *shape):
         assert x.ndim == len(shape)
-        for actual, expected in zip(x.shape, shape, strict=True):
+        # strict=False because asserted above
+        for actual, expected in zip(x.shape, shape, strict=False):
             if expected is None:
                 continue
             if actual != expected:

pytensor/link/utils.py (+4 -2)

@@ -190,8 +190,9 @@ def streamline_default_f():
        for x in no_recycling:
            x[0] = None
        try:
+           # strict=False because we are in a hot loop
            for thunk, node, old_storage in zip(
-               thunks, order, post_thunk_old_storage, strict=True
+               thunks, order, post_thunk_old_storage, strict=False
            ):
                thunk()
                for old_s in old_storage:
@@ -206,7 +207,8 @@ def streamline_nice_errors_f():
        for x in no_recycling:
            x[0] = None
        try:
-           for thunk, node in zip(thunks, order, strict=True):
+           # strict=False because we are in a hot loop
+           for thunk, node in zip(thunks, order, strict=False):
                thunk()
        except Exception:
            raise_with_op(fgraph, node, thunk)

pytensor/scalar/basic.py (+4 -2)

@@ -1150,8 +1150,9 @@ def perform(self, node, inputs, output_storage):
        else:
            variables = from_return_values(self.impl(*inputs))
            assert len(variables) == len(output_storage)
+           # strict=False because we are in a hot loop
            for out, storage, variable in zip(
-               node.outputs, output_storage, variables, strict=True
+               node.outputs, output_storage, variables, strict=False
            ):
                dtype = out.dtype
                storage[0] = self._cast_scalar(variable, dtype)
@@ -4328,7 +4329,8 @@ def make_node(self, *inputs):

    def perform(self, node, inputs, output_storage):
        outputs = self.py_perform_fn(*inputs)
-       for storage, out_val in zip(output_storage, outputs, strict=True):
+       # strict=False because we are in a hot loop
+       for storage, out_val in zip(output_storage, outputs, strict=False):
            storage[0] = out_val

    def grad(self, inputs, output_grads):

pytensor/scalar/loop.py (+3 -2)

@@ -93,7 +93,7 @@ def _validate_updates(
            )
        else:
            update = outputs
-       for i, u in zip(init[: len(update)], update, strict=True):
+       for i, u in zip(init, update, strict=False):
            if i.type != u.type:
                raise TypeError(
                    "Init and update types must be the same: "
@@ -207,7 +207,8 @@ def perform(self, node, inputs, output_storage):
        for i in range(n_steps):
            carry = inner_fn(*carry, *constant)

-       for storage, out_val in zip(output_storage, carry, strict=True):
+       # strict=False because we are in a hot loop
+       for storage, out_val in zip(output_storage, carry, strict=False):
            storage[0] = out_val

    @property

pytensor/scan/op.py (+2 -1)

@@ -1278,8 +1278,9 @@ def __eq__(self, other):
        if len(self.inner_outputs) != len(other.inner_outputs):
            return False

+       # strict=False because length already compared above
        for self_in, other_in in zip(
-           self.inner_inputs, other.inner_inputs, strict=True
+           self.inner_inputs, other.inner_inputs, strict=False
        ):
            if self_in.type != other_in.type:
                return False

pytensor/tensor/basic.py (+2 -1)

@@ -3463,7 +3463,8 @@ def perform(self, node, inp, out):

        # Make sure the output is big enough
        out_s = []
-       for xdim, ydim in zip(x_s, y_s, strict=True):
+       # strict=False because we are in a hot loop
+       for xdim, ydim in zip(x_s, y_s, strict=False):
            if xdim == ydim:
                outdim = xdim
            elif xdim == 1:

pytensor/tensor/blockwise.py (+6 -4)

@@ -342,16 +342,17 @@ def core_func(
    def _check_runtime_broadcast(self, node, inputs):
        batch_ndim = self.batch_ndim(node)

+       # strict=False because we are in a hot loop
        for dims_and_bcast in zip(
            *[
                zip(
                    input.shape[:batch_ndim],
                    sinput.type.broadcastable[:batch_ndim],
-                   strict=True,
+                   strict=False,
                )
-               for input, sinput in zip(inputs, node.inputs, strict=True)
+               for input, sinput in zip(inputs, node.inputs, strict=False)
            ],
-           strict=True,
+           strict=False,
        ):
            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                raise ValueError(
@@ -374,8 +375,9 @@ def perform(self, node, inputs, output_storage):
        if not isinstance(res, tuple):
            res = (res,)

+       # strict=False because we are in a hot loop
        for node_out, out_storage, r in zip(
-           node.outputs, output_storage, res, strict=True
+           node.outputs, output_storage, res, strict=False
        ):
            out_dtype = getattr(node_out, "dtype", None)
            if out_dtype and out_dtype != r.dtype:

pytensor/tensor/elemwise.py (+5 -3)

@@ -737,8 +737,9 @@ def perform(self, node, inputs, output_storage):
        if nout == 1:
            variables = [variables]

+       # strict=False because we are in a hot loop
        for i, (variable, storage, nout) in enumerate(
-           zip(variables, output_storage, node.outputs, strict=True)
+           zip(variables, output_storage, node.outputs, strict=False)
        ):
            storage[0] = variable = np.asarray(variable, dtype=nout.dtype)

@@ -753,12 +754,13 @@ def perform(self, node, inputs, output_storage):

    @staticmethod
    def _check_runtime_broadcast(node, inputs):
+       # strict=False because we are in a hot loop
        for dims_and_bcast in zip(
            *[
                zip(input.shape, sinput.type.broadcastable, strict=False)
-               for input, sinput in zip(inputs, node.inputs, strict=True)
+               for input, sinput in zip(inputs, node.inputs, strict=False)
            ],
-           strict=True,
+           strict=False,
        ):
            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
                raise ValueError(
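Both `Blockwise._check_runtime_broadcast` above and this `Elemwise` version use the same transposed-zip pattern: pair each runtime dimension size with the declared `broadcastable` flag for every input, then flag any dimension where a size-1, non-broadcastable input meets a larger size. A toy illustration with plain tuples (no PyTensor types involved):

# Toy version of the check: input 0 has runtime size 1 on dim 0 but was not
# declared broadcastable there, while input 1 has size 5 -> rejected.
shapes = [(1, 3), (5, 3)]
broadcastable = [(False, False), (False, False)]

try:
    for dims_and_bcast in zip(
        *[zip(shp, bcast) for shp, bcast in zip(shapes, broadcastable)]
    ):
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            raise ValueError("Runtime broadcasting not allowed")
except ValueError as err:
    print(err)  # triggered by dim 0: ((1, False), (5, False))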

pytensor/tensor/random/basic.py (+2 -1)

@@ -1862,7 +1862,8 @@ def rng_fn(cls, rng, p, size):
        # to `p.shape[:-1]` in the call to `vsearchsorted` below.
        if len(size) < (p.ndim - 1):
            raise ValueError("`size` is incompatible with the shape of `p`")
-       for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=True):
+       # strict=False because we are in a hot loop
+       for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
            if s == 1 and ps != 1:
                raise ValueError("`size` is incompatible with the shape of `p`")

pytensor/tensor/random/utils.py (+8 -4)

@@ -44,7 +44,8 @@ def params_broadcast_shapes(
    max_fn = maximum if use_pytensor else max

    rev_extra_dims: list[int] = []
-   for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True):
+   # strict=False because we are in a hot loop
+   for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
        # We need this in order to use `len`
        param_shape = tuple(param_shape)
        extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -63,11 +64,12 @@ def max_bcast(x, y):

    extra_dims = tuple(reversed(rev_extra_dims))

+   # strict=False because we are in a hot loop
    bcast_shapes = [
        (extra_dims + tuple(param_shape)[-ndim_param:])
        if ndim_param > 0
        else extra_dims
-       for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=True)
+       for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
    ]

    return bcast_shapes
@@ -110,10 +112,11 @@ def broadcast_params(
    use_pytensor = False
    param_shapes = []
    for p in params:
+       # strict=False because we are in a hot loop
        param_shape = tuple(
            1 if bcast else s
            for s, bcast in zip(
-               p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=True
+               p.shape, getattr(p, "broadcastable", (False,) * p.ndim), strict=False
            )
        )
        use_pytensor |= isinstance(p, Variable)
@@ -124,9 +127,10 @@ def broadcast_params(
    )
    broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to

+   # strict=False because we are in a hot loop
    bcast_params = [
        broadcast_to_fn(param, shape)
-       for shape, param in zip(shapes, params, strict=True)
+       for shape, param in zip(shapes, params, strict=False)
    ]

    return bcast_params

pytensor/tensor/rewriting/subtensor.py (+1 -1)

@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
    # Slices to take from val
    val_slices = []

-   for i, (sl, dim) in enumerate(zip(slices, dims[: len(slices)], strict=True)):
+   for i, (sl, dim) in enumerate(zip(slices, dims, strict=False)):
        # If val was not copied over that dim,
        # we need to take the appropriate subtensor on it.
        if i >= n_added_dims:
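This rewrite also drops the explicit `dims[: len(slices)]` slice (the `scalar/loop.py` hunk drops `init[: len(update)]` in the same way): a non-strict `zip` already stops at the shorter argument, so the slice bought nothing. A quick sanity check of that equivalence with placeholder values:

# Placeholder values only; a non-strict zip truncates exactly like the slice did.
slices = ["s0", "s1"]
dims = [10, 20, 30]

assert list(zip(slices, dims[: len(slices)])) == list(zip(slices, dims, strict=False))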

pytensor/tensor/shape.py (+6 -8)

@@ -448,8 +448,9 @@ def perform(self, node, inp, out_):
            raise AssertionError(
                f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
            )
+       # strict=False because we are in a hot loop
        if not all(
-           xs == s for xs, s in zip(x.shape, shape, strict=True) if s is not None
+           xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
        ):
            raise AssertionError(
                f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
@@ -578,15 +579,12 @@ def specify_shape(
    x = ptb.as_tensor_variable(x)  # type: ignore[arg-type,unused-ignore]
    # The above is a type error in Python 3.9 but not 3.12.
    # Thus we need to ignore unused-ignore on 3.12.
+   new_shape_info = any(
+       s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
+   )

    # If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
-   if len(shape) != x.type.ndim:
-       return _specify_shape(x, *shape)
-
-   new_shape_matches = all(
-       s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
-   )
-   if new_shape_matches:
+   if not new_shape_info and len(shape) == x.type.ndim:
        return x

    return _specify_shape(x, *shape)
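The `specify_shape` hunk is the only one that restructures logic: instead of returning early on a rank mismatch and then checking whether every given entry matches, it first computes whether the call adds any new shape information and keeps `x` unchanged only when there is none and the rank already matches. A rough standalone sketch of the resulting decision (simplified names; `type_shape` stands in for `x.type.shape`, not the actual PyTensor code):

# Simplified sketch of the rewritten decision in specify_shape, using plain tuples.
def keeps_x_unchanged(shape, type_shape):
    new_shape_info = any(
        s != xts for (s, xts) in zip(shape, type_shape, strict=False) if s is not None
    )
    # Apply the SpecifyShape Op unless nothing new is asserted and the rank matches.
    return not new_shape_info and len(shape) == len(type_shape)

assert keeps_x_unchanged((None, 3), (2, 3)) is True
assert keeps_x_unchanged((2, 3), (None, 3)) is False  # adds static info on dim 0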
