Skip to content

Commit 0a5ec73

Browse files
committed
Simplify linalg rewrites with pattern matching
1 parent ca492f9 commit 0a5ec73

File tree

4 files changed

+338
-573
lines changed

4 files changed

+338
-573
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.rewriting.basic import register_specialize
1717
from pytensor.tensor.rewriting.blockwise import blockwise_of
18-
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
1918
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
2019
from pytensor.tensor.variable import TensorVariable
2120

@@ -79,14 +78,16 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
7978
# the root variable is the pre-DimShuffled input.
8079
# Otherwise, `a` is considered the root variable.
8180
# We also return whether the root `a` is transposed.
82-
transposed = False
83-
if a.owner is not None and isinstance(a.owner.op, DimShuffle):
84-
if a.owner.op.is_left_expand_dims:
85-
[a] = a.owner.inputs
86-
elif is_matrix_transpose(a):
87-
[a] = a.owner.inputs
81+
match a.owner_op_and_inputs:
82+
case (DimShuffle(is_left_expand_dims=True), root_a):
83+
transposed = False
84+
case (DimShuffle(is_left_expanded_matrix_transpose=True), root_a):
8885
transposed = True
89-
return a, transposed
86+
case _:
87+
root_a = a
88+
transposed = False
89+
90+
return root_a, transposed
9091

9192
def find_solve_clients(var, assume_a):
9293
clients = []
@@ -119,11 +120,11 @@ def find_solve_clients(var, assume_a):
119120

120121
# Find Solves using A.T
121122
for cl, _ in fgraph.clients[A]:
122-
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
123-
A_T = cl.out
124-
A_solve_clients_and_transpose.extend(
125-
(client, True) for client in find_solve_clients(A_T, assume_a)
126-
)
123+
match (cl.op, *cl.outputs):
124+
case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T):
125+
A_solve_clients_and_transpose.extend(
126+
(client, True) for client in find_solve_clients(A_T, assume_a)
127+
)
127128

128129
if not eager and len(A_solve_clients_and_transpose) == 1:
129130
# If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager)

0 commit comments

Comments
 (0)