Skip to content

Commit db73461

Browse files
committed
Change behavior of helper set/inc to act on an indexed variable directly
1 parent c40d692 commit db73461

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

pytensor/tensor/variable.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -824,36 +824,46 @@ def compress(self, a, axis=None):
824824
"""Return selected slices only."""
825825
return pt.extra_ops.compress(self, a, axis=axis)
826826

827-
def set(self, idx, y, **kwargs):
828-
"""Return a copy of self with the indexed values set to y.
827+
def set(self, y, **kwargs):
828+
"""Return a copy of the variable indexed by self with the indexed values set to y.
829829
830-
Equivalent to set_subtensor(self[idx], y). See docstrings for kwargs.
830+
Equivalent to set_subtensor(self, y). See docstrings for kwargs.
831+
832+
Raises
833+
------
834+
TypeError:
835+
If self is not the result of a subtensor operation
831836
832837
Examples
833838
--------
834839
>>> import pytensor.tensor as pt
835840
>>>
836841
>>> x = pt.ones((3,))
837-
>>> out = x.set(1, 2)
842+
>>> out = x[1].set(2)
838843
>>> out.eval() # array([1., 2., 1.])
839844
"""
840-
return pt.subtensor.set_subtensor(self[idx], y, **kwargs)
845+
return pt.subtensor.set_subtensor(self, y, **kwargs)
846+
847+
def inc(self, y, **kwargs):
848+
"""Return a copy of the variable indexed by self with the indexed values incremented by y.
841849
842-
def inc(self, idx, y, **kwargs):
843-
"""Return a copy of self with the indexed values incremented by y.
850+
Equivalent to inc_subtensor(self, y). See docstrings for kwargs.
844851
845-
Equivalent to inc_subtensor(self[idx], y). See docstrings for kwargs.
852+
Raises
853+
------
854+
TypeError:
855+
If self is not the result of a subtensor operation
846856
847857
Examples
848858
--------
849859
850860
>>> import pytensor.tensor as pt
851861
>>>
852862
>>> x = pt.ones((3,))
853-
>>> out = x.inc(1, 2)
863+
>>> out = x[1].inc(2)
854864
>>> out.eval() # array([1., 3., 1.])
855865
"""
856-
return pt.inc_subtensor(self[idx], y, **kwargs)
866+
return pt.inc_subtensor(self, y, **kwargs)
857867

858868

859869
class TensorVariable(

tests/tensor/test_variable.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,8 @@ def test_set_inc(self):
438438
idx = [0]
439439
y = 5
440440

441-
assert equal_computations([x.set(idx, y)], [set_subtensor(x[idx], y)])
442-
assert equal_computations([x.inc(idx, y)], [inc_subtensor(x[idx], y)])
441+
assert equal_computations([x[:, idx].set(y)], [set_subtensor(x[:, idx], y)])
442+
assert equal_computations([x[:, idx].inc(y)], [inc_subtensor(x[:, idx], y)])
443443

444444
def test_set_item_error(self):
445445
x = matrix("x")

0 commit comments

Comments
 (0)