Skip to content

Commit 7a46ac2

Browse files
committed
Numba Scan: zero out unwritten buffers
1 parent 6dd8238 commit 7a46ac2

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def add_output_storage_post_proc_stmt(
254254
"""
255255
).strip()
256256
)
257+
else:
258+
# And regular loops should zero out unused entries of the output buffer
259+
# These show up with truncated gradients of while loops
260+
output_storage_post_proc_stmts.append(
261+
dedent(
262+
f"""
263+
elif {storage_size} > (i + {max_offset}):
264+
{outer_in_name}[i + {max_offset}:] = 0
265+
"""
266+
).strip()
267+
)
257268

258269
# Special in-loop statements that create (nit-sot) storage arrays after a
259270
# single iteration is performed. This is necessary because we don't know
@@ -309,7 +320,7 @@ def add_output_storage_post_proc_stmt(
309320
)
310321

311322
if outer_in_name not in outer_in_mit_mot_names:
312-
# MIT-SOT and NIT-SOT may require buffer rolling/truncation after the main loop
323+
# MIT-SOT and NIT-SOT may require buffer rolling/truncation/zeroing after the main loop
313324
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
314325
add_output_storage_post_proc_stmt(
315326
storage_name, max_offset_out_tap, storage_size_name

tests/link/numba/test_scan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,3 +673,26 @@ def test_higher_order_derivatives():
673673
[g, gg, ggg],
674674
[np.array(0.95)],
675675
)
676+
677+
678+
def test_grad_until_and_truncate_sequence_taps():
679+
# This is a case where we need special zero out behavior in Scan
680+
# Copied from tests/scan/test_basic.py::TestGradUntil::test_grad_until_and_truncate_sequence_taps
681+
x = pt.vector("x")
682+
threshold = pt.scalar(name="threshold", dtype="int64")
683+
684+
r = scan(
685+
lambda x, y, u: (x * y, until(y > u)),
686+
sequences=dict(input=x, taps=[-2, 0]),
687+
outputs_info=[None],
688+
non_sequences=[threshold],
689+
truncate_gradient=3,
690+
return_updates=False,
691+
)
692+
g = grad(r.sum(), x)
693+
694+
compare_numba_and_py(
695+
[x, threshold],
696+
[r, g],
697+
[np.arange(15, dtype=x.dtype), 6],
698+
)

0 commit comments

Comments
 (0)