|
1 | 1 | import pytensor.tensor.basic as ptb
|
2 | 2 | from pytensor.scan.basic import scan
|
3 | 3 | from pytensor.tensor.basic import Join
|
4 |
| -from pytensor.tensor.math import ceil, eq |
| 4 | +from pytensor.tensor.math import ceil, eq, neq |
5 | 5 | from pytensor.tensor.subtensor import set_subtensor
|
6 | 6 |
|
7 | 7 |
|
@@ -130,16 +130,18 @@ def scan_checkpoints(
|
130 | 130 | # Since padding could be an empty tensor, Join returns a view of s.
|
131 | 131 | join = Join(view=0)
|
132 | 132 | for i, s in enumerate(sequences):
|
133 |
| - n = s.shape[0] % save_every_N |
134 |
| - z = ptb.zeros((n, s.shape[1:]), dtype=s.dtype) |
135 |
| - sequences[i] = join(0, [s, z]) |
| 133 | + overshoots_by = s.shape[0] % save_every_N |
| 134 | + overshoots = neq(overshoots_by, 0) |
| 135 | + n = (save_every_N - overshoots_by) * overshoots |
| 136 | + z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype) |
| 137 | + sequences[i] = join(0, s, z) |
136 | 138 |
|
137 | 139 | # Establish the input variables of the outer scan
|
138 | 140 | o_sequences = [
|
139 | 141 | s.reshape(
|
140 |
| - [s.shape[0] / save_every_N, save_every_N] |
| 142 | + [s.shape[0] // save_every_N, save_every_N] |
141 | 143 | + [s.shape[i] for i in range(1, s.ndim)],
|
142 |
| - s.ndim + 1, |
| 144 | + ndim=s.ndim + 1, |
143 | 145 | )
|
144 | 146 | for s in sequences
|
145 | 147 | ]
|
|
0 commit comments