Skip to content

Commit 6dd8238

Browse files
committed
Numba Scan: correct handling of signed mitmot taps
Unlike MIT-SOT and SIT-SOT these can be positive or negative, depending on the order of differentiation
1 parent 42e8490 commit 6dd8238

File tree

2 files changed

+60
-43
lines changed

2 files changed

+60
-43
lines changed

pytensor/link/numba/dispatch/scan.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ def idx_to_str(
2727
idx_symbol: str = "i",
2828
allow_scalar=False,
2929
) -> str:
30-
if offset < 0:
31-
indices = f"{idx_symbol} + {array_name}.shape[0] - {offset}"
32-
elif offset > 0:
30+
assert offset >= 0
31+
if offset > 0:
3332
indices = f"{idx_symbol} + {offset}"
3433
else:
3534
indices = idx_symbol
@@ -226,33 +225,16 @@ def add_inner_in_expr(
226225
# storage array like a circular buffer, and that's why we need to track the
227226
# storage size along with the taps length/indexing offset.
228227
def add_output_storage_post_proc_stmt(
229-
outer_in_name: str, tap_sizes: tuple[int, ...], storage_size: str
228+
outer_in_name: str, max_offset: int, storage_size: str
230229
):
231-
tap_size = max(tap_sizes)
232-
233-
if op.info.as_while:
234-
# While loops need to truncate the output storage to a length given
235-
# by the number of iterations performed.
236-
output_storage_post_proc_stmts.append(
237-
dedent(
238-
f"""
239-
if i + {tap_size} < {storage_size}:
240-
{storage_size} = i + {tap_size}
241-
{outer_in_name} = {outer_in_name}[:{storage_size}]
242-
"""
243-
).strip()
244-
)
245-
246-
# Rotate the storage so that the last computed value is at the end of
247-
# the storage array.
230+
# Rotate the storage so that the last computed value is at the end of the storage array.
248231
# This is needed when the output storage array does not have a length
249232
# equal to the number of taps plus `n_steps`.
250-
# If the storage size only allows one entry, there's nothing to rotate
251233
output_storage_post_proc_stmts.append(
252234
dedent(
253235
f"""
254-
if 1 < {storage_size} < (i + {tap_size}):
255-
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
236+
if 1 < {storage_size} < (i + {max_offset}):
237+
{outer_in_name}_shift = (i + {max_offset}) % ({storage_size})
256238
if {outer_in_name}_shift > 0:
257239
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
258240
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
@@ -261,6 +243,18 @@ def add_output_storage_post_proc_stmt(
261243
).strip()
262244
)
263245

246+
if op.info.as_while:
247+
# While loops need to truncate the output storage to a length given
248+
# by the number of iterations performed.
249+
output_storage_post_proc_stmts.append(
250+
dedent(
251+
f"""
252+
elif {storage_size} > (i + {max_offset}):
253+
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
254+
"""
255+
).strip()
256+
)
257+
264258
# Special in-loop statements that create (nit-sot) storage arrays after a
265259
# single iteration is performed. This is necessary because we don't know
266260
# the exact shapes of the storage arrays that need to be allocated until
@@ -288,12 +282,11 @@ def add_output_storage_post_proc_stmt(
288282
storage_size_name = f"{outer_in_name}_len"
289283
storage_size_stmt = f"{storage_size_name} = {outer_in_name}.shape[0]"
290284
input_taps = inner_in_names_to_input_taps[outer_in_name]
291-
tap_storage_size = -min(input_taps)
292-
assert tap_storage_size >= 0
285+
max_lookback_inp_tap = -min(input_taps)
286+
assert max_lookback_inp_tap >= 0
293287

294288
for in_tap in input_taps:
295-
tap_offset = in_tap + tap_storage_size
296-
assert tap_offset >= 0
289+
tap_offset = max_lookback_inp_tap + in_tap
297290
is_vector = outer_in_var.ndim == 1
298291
add_inner_in_expr(
299292
outer_in_name,
@@ -302,22 +295,25 @@ def add_output_storage_post_proc_stmt(
302295
vector_slice_opt=is_vector,
303296
)
304297

305-
output_taps = inner_in_names_to_output_taps.get(
306-
outer_in_name, [tap_storage_size]
307-
)
308-
inner_out_to_outer_in_stmts.extend(
309-
idx_to_str(
310-
storage_name,
311-
out_tap,
312-
size=storage_size_name,
313-
allow_scalar=True,
298+
output_taps = inner_in_names_to_output_taps.get(outer_in_name, [0])
299+
for out_tap in output_taps:
300+
tap_offset = max_lookback_inp_tap + out_tap
301+
assert tap_offset >= 0
302+
inner_out_to_outer_in_stmts.append(
303+
idx_to_str(
304+
storage_name,
305+
tap_offset,
306+
size=storage_size_name,
307+
allow_scalar=True,
308+
)
314309
)
315-
for out_tap in output_taps
316-
)
317310

318-
add_output_storage_post_proc_stmt(
319-
storage_name, output_taps, storage_size_name
320-
)
311+
if outer_in_name not in outer_in_mit_mot_names:
312+
# MIT-SOT and NIT-SOT may require buffer rolling/truncation after the main loop
313+
max_offset_out_tap = max(output_taps) + max_lookback_inp_tap
314+
add_output_storage_post_proc_stmt(
315+
storage_name, max_offset_out_tap, storage_size_name
316+
)
321317

322318
else:
323319
storage_size_stmt = ""
@@ -351,7 +347,7 @@ def add_output_storage_post_proc_stmt(
351347
inner_out_to_outer_in_stmts.append(
352348
idx_to_str(storage_name, 0, size=storage_size_name, allow_scalar=True)
353349
)
354-
add_output_storage_post_proc_stmt(storage_name, (0,), storage_size_name)
350+
add_output_storage_post_proc_stmt(storage_name, 0, storage_size_name)
355351

356352
# In case of nit-sots we are provided the length of the array in
357353
# the iteration dimension instead of actual arrays, hence we

tests/link/numba/test_scan.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,3 +652,24 @@ def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
652652

653653
def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
654654
self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)
655+
656+
657+
def test_higher_order_derivatives():
658+
"""This tests different mit-mot taps signs"""
659+
x = pt.scalar("x")
660+
661+
xs = scan(
662+
fn=lambda xtm1: xtm1**2,
663+
outputs_info=[x],
664+
n_steps=5,
665+
return_updates=False,
666+
)
667+
g = grad(xs[-1], x)
668+
gg = grad(g, x)
669+
ggg = grad(gg, x)
670+
671+
compare_numba_and_py(
672+
[x],
673+
[g, gg, ggg],
674+
[np.array(0.95)],
675+
)

0 commit comments

Comments
 (0)