|
15 | 15 | from pytensor.tensor.elemwise import DimShuffle |
16 | 16 | from pytensor.tensor.rewriting.basic import register_specialize |
17 | 17 | from pytensor.tensor.rewriting.blockwise import blockwise_of |
18 | | -from pytensor.tensor.rewriting.linalg import is_matrix_transpose |
19 | 18 | from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve |
20 | 19 | from pytensor.tensor.variable import TensorVariable |
21 | 20 |
|
@@ -79,14 +78,16 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]: |
79 | 78 | # the root variable is the pre-DimShuffled input. |
80 | 79 | # Otherwise, `a` is considered the root variable. |
81 | 80 | # We also return whether the root `a` is transposed. |
82 | | - transposed = False |
83 | | - if a.owner is not None and isinstance(a.owner.op, DimShuffle): |
84 | | - if a.owner.op.is_left_expand_dims: |
85 | | - [a] = a.owner.inputs |
86 | | - elif is_matrix_transpose(a): |
87 | | - [a] = a.owner.inputs |
| 81 | + match a.owner_op_and_inputs: |
| 82 | + case (DimShuffle(is_left_expand_dims=True), root_a): |
| 83 | + transposed = False |
| 84 | + case (DimShuffle(is_left_expanded_matrix_transpose=True), root_a): |
88 | 85 | transposed = True |
89 | | - return a, transposed |
| 86 | + case _: |
| 87 | + root_a = a |
| 88 | + transposed = False |
| 89 | + |
| 90 | + return root_a, transposed |
90 | 91 |
|
91 | 92 | def find_solve_clients(var, assume_a): |
92 | 93 | clients = [] |
@@ -119,11 +120,11 @@ def find_solve_clients(var, assume_a): |
119 | 120 |
|
120 | 121 | # Find Solves using A.T |
121 | 122 | for cl, _ in fgraph.clients[A]: |
122 | | - if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out): |
123 | | - A_T = cl.out |
124 | | - A_solve_clients_and_transpose.extend( |
125 | | - (client, True) for client in find_solve_clients(A_T, assume_a) |
126 | | - ) |
| 123 | + match (cl.op, *cl.outputs): |
| 124 | + case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T): |
| 125 | + A_solve_clients_and_transpose.extend( |
| 126 | + (client, True) for client in find_solve_clients(A_T, assume_a) |
| 127 | + ) |
127 | 128 |
|
128 | 129 | if not eager and len(A_solve_clients_and_transpose) == 1: |
129 | 130 | # If theres' a single use don't do it... unless it's being broadcast in a Blockwise (or we're eager) |
|
0 commit comments