Skip to content

Commit 5a61625

Browse files
committed
Fix scan_checkpoints with padded sequences
1 parent db73461 commit 5a61625

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

pytensor/scan/checkpoints.py

+8 -6
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import pytensor.tensor.basic as ptb
22
from pytensor.scan.basic import scan
33
from pytensor.tensor.basic import Join
4-
from pytensor.tensor.math import ceil, eq
4+
from pytensor.tensor.math import ceil, eq, neq
55
from pytensor.tensor.subtensor import set_subtensor
66

77

@@ -130,16 +130,18 @@ def scan_checkpoints(
130130
# Since padding could be an empty tensor, Join returns a view of s.
131131
join = Join(view=0)
132132
for i, s in enumerate(sequences):
133-
n = s.shape[0] % save_every_N
134-
z = ptb.zeros((n, s.shape[1:]), dtype=s.dtype)
135-
sequences[i] = join(0, [s, z])
133+
overshoots_by = s.shape[0] % save_every_N
134+
overshoots = neq(overshoots_by, 0)
135+
n = (save_every_N - overshoots_by) * overshoots
136+
z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
137+
sequences[i] = join(0, s, z)
136138

137139
# Establish the input variables of the outer scan
138140
o_sequences = [
139141
s.reshape(
140-
[s.shape[0] / save_every_N, save_every_N]
142+
[s.shape[0] // save_every_N, save_every_N]
141143
+ [s.shape[i] for i in range(1, s.ndim)],
142-
s.ndim + 1,
144+
ndim=s.ndim + 1,
143145
)
144146
for s in sequences
145147
]

tests/scan/test_checkpoints.py

+6 -3
Original file line number | Diff line number | Diff line change
@@ -5,23 +5,26 @@
55
from pytensor.gradient import grad
66
from pytensor.scan.basic import scan
77
from pytensor.scan.checkpoints import scan_checkpoints
8-
from pytensor.tensor.basic import ones_like
8+
from pytensor.tensor.basic import arange, ones_like
99
from pytensor.tensor.type import iscalar, vector
1010

1111

1212
class TestScanCheckpoint:
1313
def setup_method(self):
1414
self.k = iscalar("k")
1515
self.A = vector("A")
16+
seq = arange(self.k, dtype="float32") + 1
1617
result, _ = scan(
17-
fn=lambda prior_result, A: prior_result * A,
18+
fn=lambda s, prior_result, A: prior_result * A / s,
1819
outputs_info=ones_like(self.A),
20+
sequences=[seq],
1921
non_sequences=self.A,
2022
n_steps=self.k,
2123
)
2224
result_check, _ = scan_checkpoints(
23-
fn=lambda prior_result, A: prior_result * A,
25+
fn=lambda s, prior_result, A: prior_result * A / s,
2426
outputs_info=ones_like(self.A),
27+
sequences=[seq],
2528
non_sequences=self.A,
2629
n_steps=self.k,
2730
save_every_N=100,

0 commit comments

Comments
 (0)