Skip to content

Commit e64cce7

Browse files
committed
Simplify linalg rewrites with pattern matching
1 parent cf91364 commit e64cce7

File tree

3 files changed

+54
-45
lines changed

3 files changed

+54
-45
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,13 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
9292
def find_solve_clients(var, assume_a):
9393
clients = []
9494
for cl, idx in fgraph.clients[var]:
95-
if (
96-
idx == 0
97-
and isinstance(cl.op, Blockwise)
98-
and isinstance(cl.op.core_op, Solve)
99-
and (cl.op.core_op.assume_a == assume_a)
100-
):
101-
clients.append(cl)
102-
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
103-
# If it's a left expand_dims, recurse on the output
104-
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
95+
match (idx, cl.op, *cl.outputs):
96+
case (0, Blockwise(Solve(assume_a=assume_a_var)), *_) if (
97+
assume_a_var == assume_a
98+
):
99+
clients.append(cl)
100+
case (0, DimShuffle(is_left_expand_dims=True), cl_out):
101+
clients.extend(find_solve_clients(cl_out, assume_a))
105102
return clients
106103

107104
assume_a = node.op.core_op.assume_a
@@ -186,34 +183,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
186183
changed = False
187184
while True:
188185
for inner_node in new_scan_fgraph.toposort():
189-
if (
190-
isinstance(inner_node.op, Blockwise)
191-
and isinstance(inner_node.op.core_op, Solve)
192-
and inner_node.op.core_op.assume_a in allowed_assume_a
193-
):
194-
A, _b = inner_node.inputs
195-
if all(
196-
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
197-
for root_inp in graph_inputs([A])
186+
match (inner_node.op, *inner_node.inputs):
187+
case (Blockwise(Solve(assume_a=assume_a_var)), A, _b) if (
188+
assume_a_var in allowed_assume_a
198189
):
199-
if new_scan_fgraph is scan_op.fgraph:
200-
# Clone the first time to avoid mutating the original fgraph
201-
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
202-
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
203-
inner_node = equiv[inner_node] # type: ignore
204-
205-
replace_dict = _split_decomp_and_solve_steps(
206-
new_scan_fgraph,
207-
inner_node,
208-
eager=True,
209-
allowed_assume_a=allowed_assume_a,
210-
)
211-
assert isinstance(replace_dict, dict) and len(replace_dict) > 0, (
212-
"Rewrite failed"
213-
)
214-
new_scan_fgraph.replace_all(replace_dict.items())
215-
changed = True
216-
break # Break to start over with a fresh toposort
190+
if all(
191+
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
192+
for root_inp in graph_inputs([A])
193+
):
194+
if new_scan_fgraph is scan_op.fgraph:
195+
# Clone the first time to avoid mutating the original fgraph
196+
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
197+
non_sequences = {
198+
equiv[non_seq] for non_seq in non_sequences
199+
}
200+
inner_node = equiv[inner_node] # type: ignore
201+
202+
replace_dict = _split_decomp_and_solve_steps(
203+
new_scan_fgraph,
204+
inner_node,
205+
eager=True,
206+
allowed_assume_a=allowed_assume_a,
207+
)
208+
assert (
209+
isinstance(replace_dict, dict) and len(replace_dict) > 0
210+
), "Rewrite failed"
211+
new_scan_fgraph.replace_all(replace_dict.items())
212+
changed = True
213+
break # Break to start over with a fresh toposort
217214
else: # no_break
218215
break # Nothing else changed
219216

pytensor/tensor/rewriting/linalg.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44

5-
from pytensor import Variable
65
from pytensor import tensor as pt
76
from pytensor.compile import optdb
87
from pytensor.graph import Apply, Constant, FunctionGraph
@@ -310,9 +309,7 @@ def log_of_prod_to_sum_of_log(fgraph, node):
310309

311310
@register_specialize
312311
@node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)])
313-
def lift_linalg_of_expanded_matrices(
314-
fgraph: FunctionGraph, node: Apply
315-
) -> list[Variable] | None:
312+
def lift_linalg_of_expanded_matrices(fgraph: FunctionGraph, node: Apply):
316313
"""
317314
Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops
318315
that join matrices (KroneckerProduct, BlockDiagonal).

tests/tensor/linalg/test_rewriting.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,21 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transpo
8686

8787
x = solve(A, b, assume_a=assume_a, transposed=transposed)
8888
grad_x_wrt_A = grad(x.sum(), A)
89-
fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name))
89+
fn_no_opt = function(
90+
[A, b],
91+
[x, grad_x_wrt_A],
92+
mode=mode.excluding(rewrite_name, "psd_solve_to_chol_solve"),
93+
)
9094
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
9195
assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 2
9296
assert counter.count_decomp_nodes(no_opt_nodes) == 0
9397
assert counter.count_solve_nodes(no_opt_nodes) == 0
9498

95-
fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name))
99+
fn_opt = function(
100+
[A, b],
101+
[x, grad_x_wrt_A],
102+
mode=mode.including(rewrite_name).excluding("psd_solve_to_chol_solve"),
103+
)
96104
opt_nodes = fn_opt.maker.fgraph.apply_nodes
97105
assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
98106
assert counter.count_decomp_nodes(opt_nodes) == 1
@@ -129,13 +137,19 @@ def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed):
129137
b = tensor("b", shape=(2, 3, 4))
130138

131139
x = solve(A, b, assume_a=assume_a, transposed=transposed)
132-
fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name))
140+
fn_no_opt = function(
141+
[A, b], [x], mode=mode.excluding(rewrite_name, "psd_solve_to_chol_solve")
142+
)
133143
no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes
134144
assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1
135145
assert counter.count_decomp_nodes(no_opt_nodes) == 0
136146
assert counter.count_solve_nodes(no_opt_nodes) == 0
137147

138-
fn_opt = function([A, b], [x], mode=mode.including(rewrite_name))
148+
fn_opt = function(
149+
[A, b],
150+
[x],
151+
mode=mode.including(rewrite_name).excluding("psd_solve_to_chol_solve"),
152+
)
139153
opt_nodes = fn_opt.maker.fgraph.apply_nodes
140154
assert counter.count_vanilla_solve_nodes(opt_nodes) == 0
141155
assert counter.count_decomp_nodes(opt_nodes) == 1
@@ -176,6 +190,7 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
176190
non_sequences=[A],
177191
n_steps=10,
178192
return_updates=False,
193+
mode=get_default_mode().excluding("psd_solve_to_chol_solve"),
179194
)
180195

181196
fn_no_opt = function(

0 commit comments

Comments
 (0)