From aefca9bc38e8468d33bc2a232adab2132ab6f5d6 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Dec 2025 08:50:31 +0100 Subject: [PATCH 1/4] Remove unnecessary dunder methods from Variable --- pytensor/graph/basic.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 1ee46f3449..71a9905c09 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -553,26 +553,6 @@ def clone(self, **kwargs): cp.tag = copy(self.tag) return cp - def __lt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __lt__", self.__class__.__name__ - ) - - def __le__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __le__", self.__class__.__name__ - ) - - def __gt__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __gt__", self.__class__.__name__ - ) - - def __ge__(self, other): - raise NotImplementedError( - "Subclasses of Variable must provide __ge__", self.__class__.__name__ - ) - def get_parents(self): if self.owner is not None: return [self.owner] From d5a8c0af585bc71094d5f19cb81778ef696f478e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Dec 2025 08:51:02 +0100 Subject: [PATCH 2/4] Facilitate graph pattern matching --- pytensor/graph/basic.py | 18 +++++++ pytensor/link/jax/dispatch/elemwise.py | 2 +- pytensor/link/mlx/dispatch/elemwise.py | 2 +- pytensor/link/pytorch/dispatch/elemwise.py | 2 +- pytensor/tensor/blockwise.py | 2 + pytensor/tensor/elemwise.py | 56 +++++++++++++-------- pytensor/tensor/rewriting/subtensor_lift.py | 2 +- tests/scan/test_printing.py | 4 +- 8 files changed, 61 insertions(+), 27 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 71a9905c09..e406d8598c 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -558,6 +558,22 @@ def get_parents(self): return [self.owner] return [] + @property + def owner_op(self) -> Optional["Op"]: + if (apply := self.owner) is not None: + return apply.op # type: ignore[no-any-return] + else: + return None + + @property + def owner_op_and_inputs( + self, + ) -> tuple[Optional["Op"], *tuple["Variable", ...]]: + if (apply := self.owner) is not None: + return apply.op, *apply.inputs # type: ignore[has-type] + else: + return (None,) + def eval( self, inputs_to_values: dict[Union["Variable", str], Any] | None = None, @@ -773,6 +789,8 @@ class Constant(AtomicVariable[_TypeType]): """ # __slots__ = ['data'] + # Allow pattern matching on data field positionally + __match_args__ = ("data",) def __init__(self, type: _TypeType, data: Any, name: str | None = None): super().__init__(type, name=name) diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py index d4c8e7b605..6ae38728bd 100644 --- a/pytensor/link/jax/dispatch/elemwise.py +++ b/pytensor/link/jax/dispatch/elemwise.py @@ -72,7 +72,7 @@ def careduce(x): @jax_funcify.register(DimShuffle) def jax_funcify_DimShuffle(op, **kwargs): def dimshuffle(x): - res = jnp.transpose(x, op.transposition) + res = jnp.transpose(x, op._transposition) shape = list(res.shape[: len(op.shuffle)]) diff --git a/pytensor/link/mlx/dispatch/elemwise.py b/pytensor/link/mlx/dispatch/elemwise.py index 19af7cd70c..c113711101 100644 --- a/pytensor/link/mlx/dispatch/elemwise.py +++ b/pytensor/link/mlx/dispatch/elemwise.py @@ -52,7 +52,7 @@ def dimshuffle(x): isinstance(x, np.number) and not isinstance(x, np.ndarray) ): x = mx.array(x) - res = mx.transpose(x, 
op.transposition) + res = mx.transpose(x, op._transposition) shape = list(res.shape[: len(op.shuffle)]) for augm in op.augment: shape.insert(augm, 1) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index a3b7683004..eac4a552c1 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -54,7 +54,7 @@ def elemwise_fn(*inputs): @pytorch_funcify.register(DimShuffle) def pytorch_funcify_DimShuffle(op, **kwargs): def dimshuffle(x): - res = torch.permute(x, op.transposition) + res = torch.permute(x, op._transposition) shape = list(res.shape[: len(op.shuffle)]) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 0181699851..6d9a68f5a7 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -160,6 +160,8 @@ class Blockwise(COp): """ __props__ = ("core_op", "signature") + # Allow pattern matching on core_op positionally + __match_args__ = ("core_op",) def __init__( self, diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index f1d8bc09df..aa1e369ffd 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -132,7 +132,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): raise TypeError(f"input_ndim must be an integer, got {type(int)}") self.input_ndim = input_ndim - self.new_order = tuple(new_order) + self.new_order = new_order = tuple(new_order) self._new_order = [(-1 if x == "x" else x) for x in self.new_order] for i, j in enumerate(new_order): @@ -153,28 +153,38 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): f"twice in the list of output dimensions: {new_order}" ) - # List of input dimensions to drop - drop = [i for i in range(input_ndim) if i not in new_order] - - # This is the list of the original dimensions that we keep - self.shuffle = [x for x in new_order if x != "x"] - self.transposition = self.shuffle + drop - # List of dimensions of the output that are broadcastable and were not - # in the original input - self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") - self.drop = drop + # Tuple of the original dimensions that we keep + self.shuffle = tuple(x for x in new_order if x != "x") + # Tuple of input dimensions to drop + self.drop = drop = tuple(i for i in range(input_ndim) if i not in new_order) + # tuple of dimensions of the output that are broadcastable and were not in the original input + self.augment = augment = tuple( + sorted(i for i, x in enumerate(new_order) if x == "x") + ) + n_augment = len(self.augment) - dims_are_shuffled = sorted(self.shuffle) != self.shuffle + # Used by perform + self._transposition = self.shuffle + drop - self.is_transpose = dims_are_shuffled and not augment and not drop - self.is_squeeze = drop and not dims_are_shuffled and not augment - self.is_expand_dims = augment and not dims_are_shuffled and not drop - self.is_left_expand_dims = self.is_expand_dims and ( - input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) + # Classify the type of dimshuffle for rewrite purposes + dims_are_shuffled = tuple(sorted(self.shuffle)) != self.shuffle + self.is_squeeze = drop and not augment and not dims_are_shuffled + self.is_expand_dims = is_expand_dims = ( + not drop and augment and not dims_are_shuffled + ) + self.is_left_expand_dims = is_expand_dims and new_order[n_augment:] == tuple( + range(input_ndim) + ) + self.is_right_expand_dims = is_expand_dims and 
new_order[:input_ndim] == tuple( + range(input_ndim) + ) + self.is_transpose = not drop and not augment and dims_are_shuffled + self.is_left_expanded_matrix_transpose = is_left_expanded_matrix_transpose = ( + dims_are_shuffled + and new_order[n_augment:] + == (*range(input_ndim - 2), input_ndim - 1, input_ndim - 2) ) - self.is_right_expand_dims = self.is_expand_dims and new_order[ - :input_ndim - ] == list(range(input_ndim)) + self.is_matrix_transpose = not augment and is_left_expanded_matrix_transpose def __setstate__(self, state): self.__dict__.update(state) @@ -212,6 +222,8 @@ def make_node(self, inp): return Apply(self, [input], [output]) def __str__(self): + if self.is_matrix_transpose: + return "MatrixTranspose" if self.is_expand_dims: if len(self.augment) == 1: return f"ExpandDims{{axis={self.augment[0]}}}" @@ -237,7 +249,7 @@ def perform(self, node, inp, out): # ) # Put dropped axis at end - res = res.transpose(self.transposition) + res = res.transpose(self._transposition) # Define new shape without dropped axis and including new ones new_shape = list(res.shape[: len(self.shuffle)]) @@ -330,6 +342,8 @@ class Elemwise(OpenMPOp): """ __props__ = ("scalar_op", "inplace_pattern") + # Allow pattern matching on scalar_op positionally + __match_args__ = ("scalar_op",) def __init__( self, scalar_op, inplace_pattern=None, name=None, nfunc_spec=None, openmp=None diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index 4d0a8cd5cb..75a3cc43ce 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -521,7 +521,7 @@ def local_subtensor_of_transpose(fgraph, node): if not ds_op.is_transpose: return None - transposition = ds_op.transposition + transposition = ds_op._transposition [x] = ds.owner.inputs idx_tuple = indices_from_subtensor(idx, node.op.idx_list) diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 5a7b45becc..8c0ee9bfa1 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -555,7 +555,7 @@ def test_debugprint_mitmot(): │ │ │ ├─ Second [id BL] │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ └─ ··· - │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BM] + │ │ │ │ └─ ExpandDims{axes=(0, 1)} [id BM] │ │ │ │ └─ 0.0 [id BN] │ │ │ ├─ IncSubtensor{i} [id BO] │ │ │ │ ├─ Second [id BP] @@ -563,7 +563,7 @@ def test_debugprint_mitmot(): │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ 1 [id BR] - │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BS] + │ │ │ │ │ └─ ExpandDims{axes=(0, 1)} [id BS] │ │ │ │ │ └─ 0.0 [id BT] │ │ │ │ ├─ Second [id BU] │ │ │ │ │ ├─ Subtensor{i} [id BV] From 7dbc03415d875982d424e6df2d8190aa210b7b56 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Dec 2025 14:50:49 +0100 Subject: [PATCH 3/4] Move non-rewrite test --- tests/tensor/rewriting/test_linalg.py | 52 +-------------------------- tests/tensor/test_nlinalg.py | 50 +++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 52 deletions(-) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 37b8afb30a..7828acf4ca 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -39,57 +39,7 @@ solve, solve_triangular, ) -from pytensor.tensor.type import dmatrix, matrix, tensor, vector -from tests import unittest_tools as utt -from tests.test_rop import break_op - - -def 
test_matrix_inverse_rop_lop(): - rtol = 1e-7 if config.floatX == "float64" else 1e-5 - mx = matrix("mx") - mv = matrix("mv") - v = vector("v") - y = MatrixInverse()(mx).sum(axis=0) - - yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True) - rop_f = function([mx, mv], yv) - - yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False) - rop_via_lop_f = function([mx, mv], yv_via_lop) - - sy, _ = pytensor.scan( - lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(), - sequences=pt.arange(y.shape[0]), - non_sequences=[y, mx, mv], - ) - scan_f = function([mx, mv], sy) - - rng = np.random.default_rng(utt.fetch_seed()) - vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) - vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) - - v_ref = scan_f(vx, vv) - np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) - np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol) - - with pytest.raises(ValueError): - pytensor.gradient.Rop( - pytensor.clone_replace(y, replace={mx: break_op(mx)}), - mx, - mv, - use_op_rop_implementation=True, - ) - - vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) - yv = pytensor.gradient.Lop(y, mx, v) - lop_f = function([mx, v], yv) - - sy = pytensor.gradient.grad((v * y).sum(), mx) - scan_f = function([mx, v], sy) - - v_ref = scan_f(vx, vv) - v = lop_f(vx, vv) - np.testing.assert_allclose(v, v_ref, rtol=rtol) +from pytensor.tensor.type import dmatrix, matrix, tensor def test_transinv_to_invtrans(): diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index 840596fbff..e3064bc11b 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -7,7 +7,7 @@ import pytensor from pytensor import function from pytensor.configdefaults import config -from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.basic import arange, as_tensor_variable from pytensor.tensor.math import _allclose from pytensor.tensor.nlinalg import ( SVD, @@ -41,6 +41,7 @@ vector, ) from tests import unittest_tools as utt +from tests.test_rop import break_op def test_pseudoinverse_correctness(): @@ -101,6 +102,53 @@ def test_infer_shape(self): self._compile_and_check([x], [xi], [r], self.op_class, warn=False) + def test_rop_lop(self): + rtol = 1e-7 if config.floatX == "float64" else 1e-5 + mx = matrix("mx") + mv = matrix("mv") + v = vector("v") + y = MatrixInverse()(mx).sum(axis=0) + + yv = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=True) + rop_f = function([mx, mv], yv) + + yv_via_lop = pytensor.gradient.Rop(y, mx, mv, use_op_rop_implementation=False) + rop_via_lop_f = function([mx, mv], yv_via_lop) + + sy, _ = pytensor.scan( + lambda i, y, x, v: (pytensor.gradient.grad(y[i], x) * v).sum(), + sequences=arange(y.shape[0]), + non_sequences=[y, mx, mv], + ) + scan_f = function([mx, mv], sy) + + rng = np.random.default_rng(utt.fetch_seed()) + vx = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) + vv = np.asarray(rng.standard_normal((4, 4)), pytensor.config.floatX) + + v_ref = scan_f(vx, vv) + np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol) + np.testing.assert_allclose(rop_via_lop_f(vx, vv), v_ref, rtol=rtol) + + with pytest.raises(ValueError): + pytensor.gradient.Rop( + pytensor.clone_replace(y, replace={mx: break_op(mx)}), + mx, + mv, + use_op_rop_implementation=True, + ) + + vv = np.asarray(rng.uniform(size=(4,)), pytensor.config.floatX) + yv = pytensor.gradient.Lop(y, mx, v) + lop_f = function([mx, v], yv) 
+ + sy = pytensor.gradient.grad((v * y).sum(), mx) + scan_f = function([mx, v], sy) + + v_ref = scan_f(vx, vv) + v = lop_f(vx, vv) + np.testing.assert_allclose(v, v_ref, rtol=rtol) + def test_matrix_dot(): rng = np.random.default_rng(utt.fetch_seed()) From e0bcea702505329217fc332de33b56337954726f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Dec 2025 14:29:13 +0100 Subject: [PATCH 4/4] Simplify linalg rewrites with pattern matching --- pytensor/tensor/_linalg/solve/rewriting.py | 97 ++- pytensor/tensor/rewriting/linalg.py | 807 ++++++++------------- pytensor/tensor/rewriting/math.py | 6 +- tests/tensor/linalg/test_rewriting.py | 23 +- tests/tensor/rewriting/test_linalg.py | 74 +- 5 files changed, 390 insertions(+), 617 deletions(-) diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py index 096acd6602..6e89556685 100644 --- a/pytensor/tensor/_linalg/solve/rewriting.py +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -15,7 +15,6 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.blockwise import blockwise_of -from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve from pytensor.tensor.variable import TensorVariable @@ -79,28 +78,26 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]: # the root variable is the pre-DimShuffled input. # Otherwise, `a` is considered the root variable. # We also return whether the root `a` is transposed. + root_a = a transposed = False - if a.owner is not None and isinstance(a.owner.op, DimShuffle): - if a.owner.op.is_left_expand_dims: - [a] = a.owner.inputs - elif is_matrix_transpose(a): - [a] = a.owner.inputs - transposed = True - return a, transposed + match a.owner_op_and_inputs: + case (DimShuffle(is_left_expand_dims=True), root_a): # type: ignore[misc] + transposed = False + case (DimShuffle(is_left_expanded_matrix_transpose=True), root_a): # type: ignore[misc] + transposed = True # type: ignore[unreachable] + + return root_a, transposed def find_solve_clients(var, assume_a): clients = [] for cl, idx in fgraph.clients[var]: - if ( - idx == 0 - and isinstance(cl.op, Blockwise) - and isinstance(cl.op.core_op, Solve) - and (cl.op.core_op.assume_a == assume_a) - ): - clients.append(cl) - elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims: - # If it's a left expand_dims, recurse on the output - clients.extend(find_solve_clients(cl.outputs[0], assume_a)) + match (idx, cl.op, *cl.outputs): + case (0, Blockwise(Solve(assume_a=assume_a_var)), *_) if ( + assume_a_var == assume_a + ): + clients.append(cl) + case (0, DimShuffle(is_left_expand_dims=True), cl_out): + clients.extend(find_solve_clients(cl_out, assume_a)) return clients assume_a = node.op.core_op.assume_a @@ -119,11 +116,11 @@ def find_solve_clients(var, assume_a): # Find Solves using A.T for cl, _ in fgraph.clients[A]: - if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out): - A_T = cl.out - A_solve_clients_and_transpose.extend( - (client, True) for client in find_solve_clients(A_T, assume_a) - ) + match (cl.op, *cl.outputs): + case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T): + A_solve_clients_and_transpose.extend( + (client, True) for client in find_solve_clients(A_T, assume_a) + ) if not eager and len(A_solve_clients_and_transpose) == 1: # If theres' a single use don't do it... 
unless it's being broadcast in a Blockwise (or we're eager) @@ -185,34 +182,34 @@ def _scan_split_non_sequence_decomposition_and_solve( changed = False while True: for inner_node in new_scan_fgraph.toposort(): - if ( - isinstance(inner_node.op, Blockwise) - and isinstance(inner_node.op.core_op, Solve) - and inner_node.op.core_op.assume_a in allowed_assume_a - ): - A, _b = inner_node.inputs - if all( - (isinstance(root_inp, Constant) or (root_inp in non_sequences)) - for root_inp in graph_inputs([A]) + match (inner_node.op, *inner_node.inputs): + case (Blockwise(Solve(assume_a=assume_a_var)), A, _b) if ( + assume_a_var in allowed_assume_a ): - if new_scan_fgraph is scan_op.fgraph: - # Clone the first time to avoid mutating the original fgraph - new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv() - non_sequences = {equiv[non_seq] for non_seq in non_sequences} - inner_node = equiv[inner_node] # type: ignore - - replace_dict = _split_decomp_and_solve_steps( - new_scan_fgraph, - inner_node, - eager=True, - allowed_assume_a=allowed_assume_a, - ) - assert isinstance(replace_dict, dict) and len(replace_dict) > 0, ( - "Rewrite failed" - ) - new_scan_fgraph.replace_all(replace_dict.items()) - changed = True - break # Break to start over with a fresh toposort + if all( + (isinstance(root_inp, Constant) or (root_inp in non_sequences)) + for root_inp in graph_inputs([A]) + ): + if new_scan_fgraph is scan_op.fgraph: + # Clone the first time to avoid mutating the original fgraph + new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv() + non_sequences = { + equiv[non_seq] for non_seq in non_sequences + } + inner_node = equiv[inner_node] # type: ignore + + replace_dict = _split_decomp_and_solve_steps( + new_scan_fgraph, + inner_node, + eager=True, + allowed_assume_a=allowed_assume_a, + ) + assert ( + isinstance(replace_dict, dict) and len(replace_dict) > 0 + ), "Rewrite failed" + new_scan_fgraph.replace_all(replace_dict.items()) + changed = True + break # Break to start over with a fresh toposort else: # no_break break # Nothing else changed diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 17a3ce9165..c66786a934 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1,32 +1,30 @@ import logging -from collections.abc import Callable -from typing import cast import numpy as np -from pytensor import Variable from pytensor import tensor as pt from pytensor.compile import optdb -from pytensor.graph import Apply, FunctionGraph +from pytensor.graph import Apply, Constant, FunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, dfs_rewriter, node_rewriter, ) from pytensor.graph.rewriting.unify import OpPattern -from pytensor.scalar.basic import Abs, Log, Mul, Sign +from pytensor.scalar.basic import Abs, Exp, Log, Mul, Sign, Sqr from pytensor.tensor.basic import ( AllocDiag, ExtractDiag, Eye, TensorVariable, + atleast_Nd, concatenate, diag, diagonal, ) from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod +from pytensor.tensor.math import Dot, Prod, log, outer, prod, variadic_mul from pytensor.tensor.nlinalg import ( SVD, KroneckerProduct, @@ -34,9 +32,7 @@ MatrixPinv, SLogDet, det, - inv, kron, - pinv, svd, ) from pytensor.tensor.rewriting.basic import ( @@ -62,71 +58,87 @@ logger = logging.getLogger(__name__) +# TODO: Make this inherit from a common abstract base class 
MATRIX_INVERSE_OPS = (MatrixInverse, MatrixPinv) def is_matrix_transpose(x: TensorVariable) -> bool: """Check if a variable corresponds to a transpose of the last two axes""" - node = x.owner - if ( - node - and isinstance(node.op, DimShuffle) - and not (node.op.drop or node.op.augment) - ): - [inp] = node.inputs - ndims = inp.type.ndim - if ndims < 2: - return False - transpose_order = (*range(ndims - 2), ndims - 1, ndims - 2) - - # Allow expand_dims on the left of the transpose - if (diff := len(transpose_order) - len(node.op.new_order)) > 0: - transpose_order = ( - *(["x"] * diff), - *transpose_order, - ) - return node.op.new_order == transpose_order + match x.owner_op: + case DimShuffle(is_left_expanded_matrix_transpose=True): + return True return False +def is_eye_mul(x) -> None | tuple[TensorVariable, TensorVariable]: + # Check if we have a Multiplication with an Eye inside + # Note: This matches for cases like (eye * 0)! + match x.owner_op_and_inputs: + case Elemwise(Mul()), *mul_inputs: + pass + case _: + return None + + x_bcast = x.type.broadcastable[-2:] + eye_input = None + non_eye_inputs = [] + for mul_input in mul_inputs: + # We only care about Eye if it's not broadcasting in the multiplication + if mul_input.type.broadcastable[-2:] == x_bcast: + match mul_input.owner_op_and_inputs: + case (Eye(), _, _, Constant(0)): + eye_input = mul_input + continue + # This whole condition checks if there is an Eye hiding inside a DimShuffle. + # This arises from batched elementwise multiplication between a tensor and an eye, e.g.: + # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites. + case (DimShuffle(is_left_expand_dims=True), ds_input): + match ds_input.owner_op_and_inputs: + case (Eye(), _, _, Constant(0)): + eye_input = mul_input + continue + # If no match: + non_eye_inputs.append(mul_input) + + if eye_input is None: + return None + + return eye_input, variadic_mul(*non_eye_inputs) + + @register_canonicalize -@node_rewriter([DimShuffle]) -def transinv_to_invtrans(fgraph, node): - if is_matrix_transpose(node.outputs[0]): - (A,) = node.inputs - if ( - A.owner - and isinstance(A.owner.op, Blockwise) - and isinstance(A.owner.op.core_op, MatrixInverse) - ): - (X,) = A.owner.inputs - return [A.owner.op(node.op(X))] +@node_rewriter([OpPattern(DimShuffle, is_left_expanded_matrix_transpose=True)]) +def transpose_of_inv(fgraph, node): + # TODO: Transpose is much more frequent that MatrixInverse, flip the rewrite pattern matching. + [A] = node.inputs + match A.owner_op_and_inputs: + case (Blockwise(MatrixInverse()) as inv_op, X): + return [inv_op(node.op(X))] @register_stabilize @node_rewriter([Dot]) -def inv_as_solve(fgraph, node): +def inv_to_solve(fgraph, node): """ This utilizes a boolean `symmetric` tag on the matrices. 
+ + TODO: Exploit other assumptions like 'triangular' and 'psd' + TODO: Handle expand_dims / matrix transpose between inv and dot """ - if isinstance(node.op, Dot): - l, r = node.inputs - if ( - l.owner - and isinstance(l.owner.op, Blockwise) - and isinstance(l.owner.op.core_op, MatrixInverse) - ): - return [solve(l.owner.inputs[0], r)] - if ( - r.owner - and isinstance(r.owner.op, Blockwise) - and isinstance(r.owner.op.core_op, MatrixInverse) - ): - x = r.owner.inputs[0] - if getattr(x.tag, "symmetric", None) is True: - return [solve(x, (l.mT)).mT] + l, r = node.inputs + match l.owner_op_and_inputs: + case (Blockwise(MatrixInverse()), X): + assume_a = "sym" if getattr(X.tag, "symmetric", False) else "gen" + return [solve(X, r, assume_a=assume_a)] + + match r.owner_op_and_inputs: + case (Blockwise(MatrixInverse()), X): + if getattr(X.tag, "symmetric", False): + return [solve(X, (l.mT), assume_a="sym").mT] else: - return [solve((x.mT), (l.mT)).mT] + return [solve((X.mT), (l.mT)).mT] + + return None @register_stabilize @@ -138,31 +150,15 @@ def generic_solve_to_solve_triangular(fgraph, node): replace it with a triangular solve. """ + b_ndim = node.op.core_op.b_ndim A, b = node.inputs # result is the solution to Ax=b - if ( - A.owner - and isinstance(A.owner.op, Blockwise) - and isinstance(A.owner.op.core_op, Cholesky) - ): - if A.owner.op.core_op.lower: - return [solve_triangular(A, b, lower=True, b_ndim=node.op.core_op.b_ndim)] - else: - return [solve_triangular(A, b, lower=False, b_ndim=node.op.core_op.b_ndim)] - if is_matrix_transpose(A): - (A_T,) = A.owner.inputs - if ( - A_T.owner - and isinstance(A_T.owner.op, Blockwise) - and isinstance(A_T.owner.op, Cholesky) - ): - if A_T.owner.op.lower: - return [ - solve_triangular(A, b, lower=False, b_ndim=node.op.core_op.b_ndim) - ] - else: - return [ - solve_triangular(A, b, lower=True, b_ndim=node.op.core_op.b_ndim) - ] + match A.owner_op_and_inputs: + case (Blockwise(Cholesky(lower=lower)), _): + return [solve_triangular(A, b, lower=lower, b_ndim=b_ndim)] + case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T): + match A_T.owner_op: + case Blockwise(Cholesky(lower=lower)): + return [solve_triangular(A, b, lower=not lower, b_ndim=b_ndim)] @register_specialize @@ -184,7 +180,7 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): if not all(a_bcast_batch_dims): return None # We squeeze degenerate dims, any that are still needed will be introduced by the new_solve - elif len(a_bcast_batch_dims): + elif a_bcast_batch_dims: a = a.squeeze(axis=tuple(range(len(a_bcast_batch_dims)))) # Recreate solve Op with b_ndim=2 @@ -214,22 +210,22 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): @register_canonicalize @register_stabilize @register_specialize -@node_rewriter([DimShuffle]) -def no_transpose_symmetric(fgraph, node): - if is_matrix_transpose(node.outputs[0]): - x = node.inputs[0] - if getattr(x.tag, "symmetric", None): - return [x] +@node_rewriter([OpPattern(DimShuffle, is_left_expanded_matrix_transpose=True)]) +def useless_symmetric_transpose(fgraph, node): + x = node.inputs[0] + if getattr(x.tag, "symmetric", False): + return [atleast_Nd(x, n=node.outputs[0].type.ndim)] @register_stabilize @node_rewriter([blockwise_of(OpPattern(Solve, b_ndim=2))]) -def psd_solve_with_chol(fgraph, node): +def psd_solve_to_chol_solve(fgraph, node): """ - This utilizes a boolean `psd` tag on matrices. + This utilizes the Solve assume_a flag or a boolean `psd` tag on matrices. 
""" + assume_a = node.op.core_op.assume_a A, b = node.inputs # result is the solution to Ax=b - if getattr(A.tag, "psd", None) is True: + if assume_a == "pos" or getattr(A.tag, "psd", None) is True: L = cholesky(A) # N.B. this can be further reduced to cho_solve Op # if no other Op makes use of the L matrix @@ -251,77 +247,69 @@ def cholesky_ldotlt(fgraph, node): This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices. """ A = node.inputs[0] - if not ( - A.owner is not None and (isinstance(A.owner.op, Dot) or (A.owner.op == _matmul)) - ): - return - - l, r = A.owner.inputs - - # cholesky(dot(L,L.T)) case - if ( - getattr(l.tag, "lower_triangular", False) - and is_matrix_transpose(r) - and r.owner.inputs[0] == l - ): - if node.op.core_op.lower: - return [l] - return [r] - - # cholesky(dot(U.T,U)) case - if ( - getattr(r.tag, "upper_triangular", False) - and is_matrix_transpose(l) - and l.owner.inputs[0] == r - ): - if node.op.core_op.lower: - return [l] - return [r] + lower = node.op.core_op.lower + + match A.owner_op_and_inputs: + case (Blockwise(Dot()) | Dot(), l, r): + lower_triangular = getattr(l.tag, "lower_triangular", False) + match (lower_triangular, r.owner_op_and_inputs): + # cholesky(dot(L,L.T)) case + case ( + True, + (DimShuffle(is_left_expanded_matrix_transpose=True), l_T), + ) if l_T == l: + return [l] if lower else [r] + + upper_triangular = getattr(r.tag, "upper_triangular", False) + match (upper_triangular, l.owner_op_and_inputs): + # cholesky(dot(U.T,U)) case + case ( + True, + (DimShuffle(is_left_expanded_matrix_transpose=True), r_T), + ) if r_T == r: + return [l] if lower else [r] @register_stabilize @register_specialize @node_rewriter([det]) -def local_det_chol(fgraph, node): +def det_of_cholesky(fgraph, node): """ If we have det(X) and there is already an L=cholesky(X) floating around, then we can use prod(diag(L)) to get the determinant. """ - (x,) = node.inputs - for cl, xpos in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky): - L = cl.outputs[0] - return [prod(diagonal(L, axis1=-2, axis2=-1) ** 2, axis=-1)] + [x] = node.inputs + for cl, _ in fgraph.clients[x]: + match (cl.op, *cl.outputs): + case (Blockwise(Cholesky()), L): + return [prod(diagonal(L, axis1=-2, axis2=-1), axis=-1) ** 2] -@register_canonicalize @register_stabilize @register_specialize @node_rewriter([log]) -def local_log_prod_sqr(fgraph, node): +def log_of_prod_to_sum_of_log(fgraph, node): """ This utilizes a boolean `positive` tag on matrices. """ - (x,) = node.inputs - if x.owner and isinstance(x.owner.op, Prod): - # we cannot always make this substitution because - # the prod might include negative terms - p = x.owner.inputs[0] + [x] = node.inputs + match x.owner_op_and_inputs: + case (Prod(axis=axis), p): + # TODO: have a reduction like prod and sum that simply + # returns the sign of the prod multiplication. - # p is the matrix we're reducing with prod - if getattr(p.tag, "positive", None) is True: - return [log(p).sum(axis=x.owner.op.axis)] + match p.owner_op: + case Elemwise(Abs() | Sqr() | Exp()): + return [log(p).sum(axis=axis)] - # TODO: have a reduction like prod and sum that simply - # returns the sign of the prod multiplication. 
+ if getattr(p.tag, "positive", False): + return [log(p).sum(axis=axis)] @register_specialize @node_rewriter([blockwise_of(MatrixInverse | Cholesky | MatrixPinv)]) -def local_lift_through_linalg( - fgraph: FunctionGraph, node: Apply -) -> list[Variable] | None: +def lift_linalg_of_expanded_matrices(fgraph: FunctionGraph, node: Apply): """ Rewrite compositions of linear algebra operations by lifting expensive operations (Cholesky, Inverse) through Ops that join matrices (KroneckerProduct, BlockDiagonal). @@ -346,106 +334,20 @@ def local_lift_through_linalg( """ # TODO: Simplify this if we end up Blockwising KroneckerProduct - y = node.inputs[0] outer_op = node.op + [y] = node.inputs - if y.owner and ( - ( - isinstance(y.owner.op, Blockwise) - and isinstance(y.owner.op.core_op, BlockDiagonal) - ) - or isinstance(y.owner.op, KroneckerProduct) - ): - input_matrices = y.owner.inputs - - if isinstance(outer_op.core_op, MatrixInverse): - outer_f = cast(Callable, inv) - elif isinstance(outer_op.core_op, Cholesky): - outer_f = cast(Callable, cholesky) - elif isinstance(outer_op.core_op, MatrixPinv): - outer_f = cast(Callable, pinv) - else: - raise NotImplementedError # pragma: no cover - - inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices] - - if isinstance(y.owner.op, KroneckerProduct): - return [kron(*inner_matrices)] - elif isinstance(y.owner.op.core_op, BlockDiagonal): - return [block_diag(*inner_matrices)] - else: - raise NotImplementedError # pragma: no cover - return None - - -def _find_diag_from_eye_mul(potential_mul_input): - # Check if the op is Elemwise and mul - if not ( - potential_mul_input.owner is not None - and isinstance(potential_mul_input.owner.op, Elemwise) - and isinstance(potential_mul_input.owner.op.scalar_op, Mul) - ): - return None - - # Find whether any of the inputs to mul is Eye - inputs_to_mul = potential_mul_input.owner.inputs - eye_input = [ - mul_input - for mul_input in inputs_to_mul - if mul_input.owner - and ( - isinstance(mul_input.owner.op, Eye) - or - # This whole condition checks if there is an Eye hiding inside a DimShuffle. - # This arises from batched elementwise multiplication between a tensor and an eye, e.g.: - # tensor(shape=(None, 3, 3) * eye(3). This is still potentially valid for diag rewrites. - ( - isinstance(mul_input.owner.op, DimShuffle) - and ( - mul_input.owner.op.is_left_expand_dims - or mul_input.owner.op.is_right_expand_dims - ) - and mul_input.owner.inputs[0].owner is not None - and isinstance(mul_input.owner.inputs[0].owner.op, Eye) - ) - ) - ] - - if not eye_input: - return None - - eye_input = eye_input[0] - # If eye_input is an Eye Op (it's not wrapped in a DimShuffle), check it doesn't have an offset - if isinstance(eye_input.owner.op, Eye) and ( - not Eye.is_offset_zero(eye_input.owner) - or eye_input.broadcastable[-2:] != (False, False) - ): - return None - - # Otherwise, an Eye was found but it is wrapped in a DimShuffle (i.e. there was some broadcasting going on). 
- # We have to look inside DimShuffle to decide if the rewrite can be applied - if isinstance(eye_input.owner.op, DimShuffle) and ( - eye_input.owner.op.is_left_expand_dims - or eye_input.owner.op.is_right_expand_dims - ): - inner_eye = eye_input.owner.inputs[0] - # We can only rewrite when the Eye is on the main diagonal (the offset is zero) and the identity isn't - # degenerate - if not Eye.is_offset_zero(inner_eye.owner) or inner_eye.broadcastable[-2:] != ( - False, - False, - ): - return None - - # Get all non Eye inputs (scalars/matrices/vectors) - non_eye_inputs = list(set(inputs_to_mul) - {eye_input}) - return eye_input, non_eye_inputs + match y.owner_op_and_inputs: + case (Blockwise(BlockDiagonal()), *inner_matrices): + return [block_diag(*(outer_op(m) for m in inner_matrices))] + case (KroneckerProduct(), *inner_matrices): + return [kron(*(outer_op(m) for m in inner_matrices))] # type: ignore[unreachable] @register_canonicalize("shape_unsafe") @register_stabilize("shape_unsafe") @node_rewriter([det]) -def rewrite_det_diag_to_prod_diag(fgraph, node): +def det_of_diag(fgraph, node): """ This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. @@ -467,43 +369,34 @@ def rewrite_det_diag_to_prod_diag(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - inputs = node.inputs[0] + inp = node.inputs[0] - # Check for use of pt.diag first - if ( - inputs.owner - and isinstance(inputs.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(inputs.owner) - ): - diag_input = inputs.owner.inputs[0] - det_val = diag_input.prod(axis=-1) - return [det_val] + match inp.owner_op_and_inputs: + # Check for use of pt.diag first + case (AllocDiag(offset=0, axis1=axis1, axis2=axis2), diag_input): + ndim = diag_input.ndim + if axis1 == ndim - 1 and axis2 == ndim: + return [diag_input.prod(axis=-1)] # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(inputs) - if inputs_or_none is None: - return None - - eye_input, non_eye_inputs = inputs_or_none - - # Dealing with only one other input - if len(non_eye_inputs) != 1: - return None - - eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] - - # Checking if original x was scalar/vector/matrix - if non_eye_input.type.broadcastable[-2:] == (True, True): - # For scalar - det_val = non_eye_input.squeeze(axis=(-1, -2)) ** (eye_input.shape[0]) - elif non_eye_input.type.broadcastable[-2:] == (False, False): - # For Matrix - det_val = non_eye_input.diagonal(axis1=-1, axis2=-2).prod(axis=-1) - else: - # For vector - det_val = non_eye_input.prod(axis=(-1, -2)) - det_val = det_val.astype(node.outputs[0].type.dtype) - return [det_val] + match is_eye_mul(inp): + case (eye_term, non_eye_term): + # Checking if original x was scalar/vector/matrix + match non_eye_term.type.broadcastable[-2:]: + case (True, True): + # For scalar + det_val = ( + non_eye_term.squeeze(axis=(-1, -2)) ** (eye_term.shape[-1]) + ) + case (False, False): + # For Matrix + det_val = non_eye_term.diagonal(axis1=-1, axis2=-2).prod(axis=-1) + case _: + # For vector + det_val = non_eye_term.prod(axis=(-1, -2)) + + det_val = det_val.astype(node.outputs[0].type.dtype) + return [det_val] @register_canonicalize @@ -515,50 +408,51 @@ def svd_uv_merge(fgraph, node): `compute_uv=True`, then we can change `compute_uv = False` to `True` everywhere and allow `pytensor` to re-use the 
decomposition outputs instead of recomputing. """ - (x,) = node.inputs + [x] = node.inputs if node.op.core_op.compute_uv: # compute_uv=True returns [u, s, v]. + u, s, v = node.outputs + # if at least u or v is used, no need to rewrite this node. - if ( - len(fgraph.clients[node.outputs[0]]) > 0 - or len(fgraph.clients[node.outputs[2]]) > 0 - ): - return + if fgraph.clients[u] or fgraph.clients[v]: + return None # Else, has to replace the s of this node with s of an SVD Op that compute_uv=False. # First, iterate to see if there is an SVD Op that can be reused. for cl, _ in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if not cl.op.core_op.compute_uv: - return { - node.outputs[1]: cl.outputs[0], - } - - # If no SVD reusable, return a new one. - return { - node.outputs[1]: svd( - x, full_matrices=node.op.core_op.full_matrices, compute_uv=False - ), - } + if cl is node: + continue + match (cl.op, *cl.outputs): + case (Blockwise(SVD(compute_uv=False)), replacement_s): + break + else: + # If no SVD reusable, return a new one. + replacement_s = svd( + x, + full_matrices=node.op.core_op.full_matrices, + compute_uv=False, + ) + return {s: replacement_s} else: # compute_uv=False returns [s]. # We want rewrite if there is another one with compute_uv=True. # For this case, just reuse the `s` from the one with compute_uv=True. for cl, _ in fgraph.clients[x]: - if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, SVD): - if cl.op.core_op.compute_uv and ( - len(fgraph.clients[cl.outputs[0]]) > 0 - or len(fgraph.clients[cl.outputs[2]]) > 0 + if cl is node: + continue + match (cl.op, *cl.outputs): + case (Blockwise(SVD(compute_uv=True)), u, s, v) if ( + fgraph.clients[u] or fgraph.clients[v] ): - return [cl.outputs[1]] + return [s] @register_canonicalize @register_stabilize @node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) -def rewrite_inv_inv(fgraph, node): +def useless_consecutive_inv(fgraph, node): """ This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once. @@ -576,56 +470,16 @@ def rewrite_inv_inv(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - # Check if its a valid inverse operation (either inv/pinv) - # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation - # If the outer operation is not a valid inverse, we do not apply this rewrite - potential_inner_inv = node.inputs[0].owner - if potential_inner_inv is None or potential_inner_inv.op is None: - return None - - # Check if inner op is blockwise and and possible inv - if not ( - potential_inner_inv - and isinstance(potential_inner_inv.op, Blockwise) - and isinstance(potential_inner_inv.op.core_op, MATRIX_INVERSE_OPS) - ): - return None - return [potential_inner_inv.inputs[0]] - - -@register_canonicalize -@register_stabilize -@node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) -def rewrite_inv_eye_to_eye(fgraph, node): - """ - This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself - The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op. 
- Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check whether input to inverse is Eye and the 1's are on main diagonal - potential_eye = node.inputs[0] - if not ( - potential_eye.owner - and isinstance(potential_eye.owner.op, Eye) - and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0 - ): - return None - return [potential_eye] + # Check if inner op is blockwise and possible inv + match node.inputs[0].owner_op_and_inputs: + case (Blockwise(MatrixInverse() | MatrixPinv()), X): + return [X] @register_canonicalize @register_stabilize @node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)]) -def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): +def inv_of_diag_to_diag_reciprocal(fgraph, node): """ This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix @@ -642,42 +496,32 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - inputs = node.inputs[0] - # Check for use of pt.diag first - if ( - inputs.owner - and isinstance(inputs.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(inputs.owner) - ): - inv_input = inputs.owner.inputs[0] - inv_val = pt.diag(1 / inv_input) - return [inv_val] + inp = node.inputs[0] - # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(inputs) - if inputs_or_none is None: - return None - - eye_input, non_eye_inputs = inputs_or_none - - # Dealing with only one other input - if len(non_eye_inputs) != 1: - return None - - non_eye_input = non_eye_inputs[0] + # Check for diagonal constructors first + match inp.owner_op_and_inputs: + case (Eye(), _, _, Constant(0)): + return [inp] + case (AllocDiag(offset=0, axis1=axis1, axis2=axis2), inv_input): + ndim = inv_input.type.ndim + if axis1 == ndim - 1 and axis2 == ndim: + return [pt.diag(1 / inv_input)] - # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those - if non_eye_input.type.broadcastable[-2:] == (False, False): - non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2) - non_eye_input = pt.shape_padaxis(non_eye_diag, -2) + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + match is_eye_mul(inp): + case (eye_term, non_eye_term): + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_term.type.broadcastable[-2:] == (False, False): + non_eye_diag = non_eye_term.diagonal(axis1=-1, axis2=-2) + non_eye_term = pt.shape_padaxis(non_eye_diag, -2) - return [eye_input / non_eye_input] + return [eye_term / non_eye_term] @register_canonicalize @register_stabilize @node_rewriter([ExtractDiag]) -def rewrite_diag_blockdiag(fgraph, node): +def diag_of_blockdiag(fgraph, node): """ This rewrite simplifies extracting the diagonal of a blockdiagonal matrix by concatening the diagonal values of all of the individual sub matrices. 
@@ -696,25 +540,16 @@ def rewrite_diag_blockdiag(fgraph, node): List of optimized variables, or None if no optimization was performed """ # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - submatrices = potential_block_diag.inputs - submatrices_diag = [diag(submatrices[i]) for i in range(len(submatrices))] - - return [concatenate(submatrices_diag)] + match node.inputs[0].owner_op_and_inputs: + case (Blockwise(BlockDiagonal()), *submatrices): + submatrices_diag = [diag(m) for m in submatrices] + return [concatenate(submatrices_diag, axis=-1)] @register_canonicalize @register_stabilize @node_rewriter([det]) -def rewrite_det_blockdiag(fgraph, node): +def det_of_blockdiag(fgraph, node): """ This rewrite simplifies the determinant of a blockdiagonal matrix by extracting the individual sub matrices and returning the product of all individual determinant values. @@ -733,25 +568,16 @@ def rewrite_det_blockdiag(fgraph, node): List of optimized variables, or None if no optimization was performed """ # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - sub_matrices = potential_block_diag.inputs - det_sub_matrices = [det(sub_matrices[i]) for i in range(len(sub_matrices))] - - return [prod(det_sub_matrices)] + match node.inputs[0].owner_op_and_inputs: + case (Blockwise(BlockDiagonal()), *sub_matrices): + det_sub_matrices = [det(m) for m in sub_matrices] + return [prod(det_sub_matrices, axis=-1)] @register_canonicalize @register_stabilize @node_rewriter([ExtractDiag]) -def rewrite_diag_kronecker(fgraph, node): +def diag_of_kronecker(fgraph, node): """ This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector. 
@@ -770,22 +596,17 @@ def rewrite_diag_kronecker(fgraph, node): List of optimized variables, or None if no optimization was performed """ # Check for inner kron operation - potential_kron = node.inputs[0].owner - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): - return None - - # Find the matrices - a, b = potential_kron.inputs - diag_a, diag_b = diag(a), diag(b) - outer_prod_as_vector = outer(diag_a, diag_b).flatten() - - return [outer_prod_as_vector] + match node.inputs[0].owner_op_and_inputs: + case (KroneckerProduct(), a, b): + diag_a, diag_b = diag(a), diag(b) + outer_prod_as_vector = outer(diag_a, diag_b).flatten() + return [outer_prod_as_vector] @register_canonicalize @register_stabilize @node_rewriter([det]) -def rewrite_det_kronecker(fgraph, node): +def det_of_kronecker(fgraph, node): """ This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those @@ -802,96 +623,47 @@ def rewrite_det_kronecker(fgraph, node): List of optimized variables, or None if no optimization was performed """ # Check for inner kron operation - potential_kron = node.inputs[0].owner - if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): - return None - - # Find the matrices - a, b = potential_kron.inputs - dets = [det(a), det(b)] - sizes = [a.shape[-1], b.shape[-1]] - prod_sizes = prod(sizes, no_zeros_in_input=True) - det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)]) - - return [det_final] - - -@register_canonicalize -@register_stabilize -@node_rewriter([blockwise_of(Cholesky)]) -def rewrite_remove_useless_cholesky(fgraph, node): - """ - This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself - - The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky. 
- - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Find whether cholesky op is being applied - - # Check whether input to Cholesky is Eye and the 1's are on main diagonal - potential_eye = node.inputs[0] - if not ( - potential_eye.owner - and isinstance(potential_eye.owner.op, Eye) - and hasattr(potential_eye.owner.inputs[-1], "data") - and potential_eye.owner.inputs[-1].data.item() == 0 - ): - return None - return [potential_eye] + match node.inputs[0].owner_op_and_inputs: + case (KroneckerProduct(), a, b): + dets = [det(a), det(b)] + sizes = [a.shape[-1], b.shape[-1]] + prod_sizes = prod(sizes, no_zeros_in_input=True) + det_final = prod( + [dets[i] ** (prod_sizes / sizes[i]) for i in range(2)], axis=-1 + ) + return [det_final] @register_canonicalize @register_stabilize @node_rewriter([blockwise_of(Cholesky)]) -def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node): - [input] = node.inputs +def cholesky_of_diag(fgraph, node): + [X] = node.inputs # Check if input is a (1, 1) matrix - if all(input.type.broadcastable[-2:]): - return [pt.sqrt(input)] - - # Check for use of pt.diag first - if ( - input.owner - and isinstance(input.owner.op, AllocDiag) - and AllocDiag.is_offset_zero(input.owner) - ): - diag_input = input.owner.inputs[0] - cholesky_val = pt.diag(diag_input**0.5) - return [cholesky_val] + if all(X.type.broadcastable[-2:]): + return [pt.sqrt(X)] + + match X.owner_op_and_inputs: + # Check whether input to Cholesky is Eye and the 1's are on main diagonal + case (Eye(), _, _, Constant(0)): + return [X] + case (AllocDiag(offset=0, axis1=axis1, axis2=axis2), diag_input): + ndim = diag_input.ndim + if axis1 == ndim - 1 and axis2 == ndim: + return [pt.diag(diag_input**0.5)] # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix - inputs_or_none = _find_diag_from_eye_mul(input) - if inputs_or_none is None: - return None - - eye_input, non_eye_inputs = inputs_or_none - - # Dealing with only one other input - if len(non_eye_inputs) != 1: - return None - - [non_eye_input] = non_eye_inputs - - # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements - # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those - if non_eye_input.type.broadcastable[-2:] == (False, False): - non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) - if eye_input.type.ndim > 2: - non_eye_input = pt.shape_padaxis(non_eye_input, -2) + match is_eye_mul(X): + case (eye_input, non_eye_input): + # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_input.type.broadcastable[-2:] == (False, False): + non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2) + if eye_input.type.ndim > 2: + non_eye_input = pt.shape_padaxis(non_eye_input, -2) - return [eye_input * (non_eye_input**0.5)] + return [eye_input * (non_eye_input**0.5)] @node_rewriter([_bilinear_solve_discrete_lyapunov]) @@ -899,9 +671,8 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): """ Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX """ - A, B = (cast(TensorVariable, x) for x in 
node.inputs) + A, B = node.inputs result = solve_discrete_lyapunov(A, B, method="direct") - return [result] @@ -933,46 +704,39 @@ def slogdet_specialization(fgraph, node): """ dummy_replacements = {} for client, _ in fgraph.clients[node.outputs[0]]: - # Check for sign(det) - if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign): - dummy_replacements[client.outputs[0]] = "sign" - - # Check for log(abs(det)) - elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs): - potential_log = None - for client_2, _ in fgraph.clients[client.outputs[0]]: - if isinstance(client_2.op, Elemwise) and isinstance( - client_2.op.scalar_op, Log - ): - potential_log = client_2 - if potential_log: - dummy_replacements[potential_log.outputs[0]] = "log_abs_det" - else: + match (client.op, *client.outputs): + # Check for sign(det) + case (Elemwise(Sign()), sign): + dummy_replacements[sign] = "sign" + + # Check for log(abs(det)) + case (Elemwise(Abs()), potential_log): + for client_2, _ in fgraph.clients[potential_log]: + match (client_2.op, *client_2.outputs): + case (Elemwise(Log()), log_abs_det): + dummy_replacements[log_abs_det] = "log_abs_det" + case _: + return None + + case (Elemwise(Log()), log_det): + dummy_replacements[log_det] = "log_det" + + case _: + # Det is used directly for something else, don't rewrite to avoid computing two dets return None - # Check for log(det) - elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log): - dummy_replacements[client.outputs[0]] = "log_det" - - # Det is used directly for something else, don't rewrite to avoid computing two dets - else: - return None - if not dummy_replacements: return None - else: - [x] = node.inputs - sign_det_x, log_abs_det_x = SLogDet()(x) - log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) - slogdet_specialization_map = { - "sign": sign_det_x, - "log_abs_det": log_abs_det_x, - "log_det": log_det_x, - } - replacements = { - k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() - } - return replacements + + [x] = node.inputs + sign_det_x, log_abs_det_x = SLogDet()(x) + log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) + slogdet_specialization_map = { + "sign": sign_det_x, + "log_abs_det": log_abs_det_x, + "log_det": log_det_x, + } + return {k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()} @register_stabilize @@ -982,16 +746,13 @@ def scalar_solve_to_division(fgraph, node): """ Replace solve(a, b) with b / a if a is a (1, 1) matrix """ - - core_op = node.op.core_op - if not isinstance(core_op, SolveBase): - return None - a, b = node.inputs old_out = node.outputs[0] if not all(a.broadcastable[-2:]): return None + core_op = node.op.core_op + if core_op.b_ndim == 1: # Convert b to a column matrix b = b[..., None] diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 1530a0ed90..abf0015824 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -105,7 +105,6 @@ ) from pytensor.tensor.rewriting.blockwise import blockwise_of from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift -from pytensor.tensor.rewriting.linalg import is_matrix_transpose from pytensor.tensor.shape import Shape, Shape_i, specify_shape from pytensor.tensor.slinalg import BlockDiagonal from pytensor.tensor.subtensor import Subtensor @@ -235,7 +234,10 @@ def local_lift_transpose_through_dot(fgraph, node): [(client, _)] = clients - if not 
(isinstance(client.op, DimShuffle) and is_matrix_transpose(client.out)): + if not ( + isinstance(client.op, DimShuffle) + and client.op.is_left_expanded_matrix_transpose + ): return None x, y = node.inputs diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py index 2e2b11257d..cf0b981de2 100644 --- a/tests/tensor/linalg/test_rewriting.py +++ b/tests/tensor/linalg/test_rewriting.py @@ -86,13 +86,21 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transpo x = solve(A, b, assume_a=assume_a, transposed=transposed) grad_x_wrt_A = grad(x.sum(), A) - fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) + fn_no_opt = function( + [A, b], + [x, grad_x_wrt_A], + mode=mode.excluding(rewrite_name, "psd_solve_to_chol_solve"), + ) no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 2 assert counter.count_decomp_nodes(no_opt_nodes) == 0 assert counter.count_solve_nodes(no_opt_nodes) == 0 - fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) + fn_opt = function( + [A, b], + [x, grad_x_wrt_A], + mode=mode.including(rewrite_name).excluding("psd_solve_to_chol_solve"), + ) opt_nodes = fn_opt.maker.fgraph.apply_nodes assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 assert counter.count_decomp_nodes(opt_nodes) == 1 @@ -129,13 +137,19 @@ def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed): b = tensor("b", shape=(2, 3, 4)) x = solve(A, b, assume_a=assume_a, transposed=transposed) - fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) + fn_no_opt = function( + [A, b], [x], mode=mode.excluding(rewrite_name, "psd_solve_to_chol_solve") + ) no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1 assert counter.count_decomp_nodes(no_opt_nodes) == 0 assert counter.count_solve_nodes(no_opt_nodes) == 0 - fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) + fn_opt = function( + [A, b], + [x], + mode=mode.including(rewrite_name).excluding("psd_solve_to_chol_solve"), + ) opt_nodes = fn_opt.maker.fgraph.apply_nodes assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 assert counter.count_decomp_nodes(opt_nodes) == 1 @@ -176,6 +190,7 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed): non_sequences=[A], n_steps=10, return_updates=False, + mode=get_default_mode().excluding("psd_solve_to_chol_solve"), ) fn_no_opt = function( diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 7828acf4ca..ce5ddc376e 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -23,10 +23,12 @@ MatrixInverse, MatrixPinv, SLogDet, + inv, matrix_inverse, + pinv, svd, ) -from pytensor.tensor.rewriting.linalg import inv_as_solve +from pytensor.tensor.rewriting.linalg import inv_to_solve from pytensor.tensor.slinalg import ( BlockDiagonal, Cholesky, @@ -42,7 +44,7 @@ from pytensor.tensor.type import dmatrix, matrix, tensor -def test_transinv_to_invtrans(): +def test_transpose_of_inv(): X = matrix("X") Y = matrix_inverse(X) Z = Y.transpose() @@ -96,11 +98,11 @@ def test_generic_solve_to_solve_triangular(): ) -def test_matrix_inverse_solve(): +def test_inv_to_solve(): A = dmatrix("A") b = dmatrix("b") node = matrix_inverse(A).dot(b).owner - [out] = inv_as_solve.transform(None, node) + [out] = inv_to_solve.transform(None, node) assert isinstance(out.owner.op, 
         out.owner.op.core_op, Solve
     )
@@ -187,7 +189,7 @@ def test_cholesky_ldotlt(tag, cholesky_form, product, op):
     )
 
 
-def test_local_det_chol():
+def test_det_of_cholesky():
     X = matrix("X")
     L = pt.linalg.cholesky(X)
     det_X = pt.linalg.det(X)
@@ -314,7 +316,7 @@ def test_invalid_batched_a(self):
     [(BlockDiagonal, pt.linalg.block_diag), (KroneckerProduct, pt.linalg.kron)],
     ids=["block_diag", "kron"],
 )
-def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
+def test_lift_linalg_of_expanded_matrices(constructor, f_op, f, g_op, g):
     rng = np.random.default_rng(sum(map(ord, "lift_through_linalg")))
 
     if pytensor.config.floatX.endswith("32"):
@@ -353,7 +355,7 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
     [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
     ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
 )
-def test_det_diag_from_eye_mul(shape):
+def test_det_of_diag_from_eye_mul(shape):
     # Initializing x based on scalar/vector/matrix
     x = pt.tensor("x", shape=shape)
     y = pt.eye(7) * x
@@ -390,7 +392,7 @@ def test_det_diag_from_eye_mul(shape):
     )
 
 
-def test_det_diag_from_diag():
+def test_det_of_diag_from_diag():
     x = pt.tensor("x", shape=(None,))
     x_diag = pt.diag(x)
     y = pt.linalg.det(x_diag)
@@ -414,12 +416,12 @@
     )
 
 
-def test_dont_apply_det_diag_rewrite_for_1_1():
+def test_dont_apply_det_of_diag_from_scalar_eye():
     x = pt.matrix("x")
     x_diag = pt.eye(1, 1) * x
     y = pt.linalg.det(x_diag)
     f_rewritten = function([x], y, mode="FAST_RUN")
-
+    f_rewritten.dprint()
     nodes = f_rewritten.maker.fgraph.apply_nodes
     assert any(isinstance(node.op, Det) for node in nodes)
 
@@ -438,7 +440,7 @@
     )
 
 
-def test_det_diag_incorrect_for_rectangle_eye():
+def test_det_of_diag_incorrect_for_rectangle_eye():
     x = pt.matrix("x")
     x_diag = pt.eye(7, 5) * x
     with pytest.raises(ValueError, match="Determinant not defined"):
@@ -509,24 +511,20 @@ def test_svd_uv_merge():
     assert svd_counter == 1
 
 
-def get_pt_function(x, op_name):
-    return getattr(pt.linalg, op_name)(x)
-
-
-@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
-@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
-def test_inv_inv_rewrite(inv_op_1, inv_op_2):
+@pytest.mark.parametrize("inv_op_1", [inv, pinv])
+@pytest.mark.parametrize("inv_op_2", [inv, pinv])
+def test_useless_consecutive_inv(inv_op_1, inv_op_2):
     x = pt.matrix("x")
-    op1 = get_pt_function(x, inv_op_1)
-    op2 = get_pt_function(op1, inv_op_2)
-    rewritten_out = rewrite_graph(op2)
+    inv_x = inv_op_1(x)
+    x_again = inv_op_2(inv_x)
+    rewritten_out = rewrite_graph(x_again)
     assert rewritten_out == x
 
 
-@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
-def test_inv_eye_to_eye(inv_op):
+@pytest.mark.parametrize("inv_op", [inv, pinv])
+def test_inv_of_diag_from_eye(inv_op):
     x = pt.eye(10)
-    x_inv = get_pt_function(x, inv_op)
+    x_inv = inv_op(x)
     f_rewritten = function([], x_inv, mode="FAST_RUN")
 
     nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -552,13 +550,13 @@ def test_inv_eye_to_eye(inv_op):
     [(), (7,), (7, 7), (5, 7, 7)],
     ids=["scalar", "vector", "matrix", "batched"],
 )
-@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
-def test_inv_diag_from_eye_mul(shape, inv_op):
+@pytest.mark.parametrize("inv_op", [inv, pinv])
+def test_inv_of_diag_from_eye_mul(shape, inv_op):
     # Initializing x based on scalar/vector/matrix
     x = pt.tensor("x", shape=shape)
     x_diag = pt.eye(7) * x
     # Calculating inverse using pt.linalg.inv
-    x_inv = get_pt_function(x_diag, inv_op)
+    x_inv = inv_op(x_diag)
 
     # REWRITE TEST
     f_rewritten = function([x], x_inv, mode="FAST_RUN")
@@ -587,11 +585,11 @@
     )
 
 
-@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
-def test_inv_diag_from_diag(inv_op):
+@pytest.mark.parametrize("inv_op", [inv, pinv])
+def test_inv_of_diag_to_diag_reciprocal(inv_op):
     x = pt.dvector("x")
     x_diag = pt.diag(x)
-    x_inv = get_pt_function(x_diag, inv_op)
+    x_inv = inv_op(x_diag)
 
     # REWRITE TEST
     f_rewritten = function([x], x_inv, mode="FAST_RUN")
@@ -615,7 +613,7 @@
     )
 
 
-def test_diag_blockdiag_rewrite():
+def test_diag_of_blockdiag():
     n_matrices = 10
     matrix_size = (5, 5)
     sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
@@ -642,7 +640,7 @@
     )
 
 
-def test_det_blockdiag_rewrite():
+def test_det_of_blockdiag():
     n_matrices = 100
     matrix_size = (5, 5)
     sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
@@ -669,7 +667,7 @@
     )
 
 
-def test_slogdet_blockdiag_rewrite():
+def test_slogdet_of_blockdiag():
     n_matrices = 10
     matrix_size = (5, 5)
     sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
@@ -704,7 +702,7 @@
     )
 
 
-def test_diag_kronecker_rewrite():
+def test_diag_of_kronecker():
     a, b = pt.dmatrices("a", "b")
     kron_prod = pt.linalg.kron(a, b)
     diag_kron_prod = pt.diag(kron_prod)
@@ -727,7 +725,7 @@
     )
 
 
-def test_det_kronecker_rewrite():
+def test_det_of_kronecker():
     a, b = pt.dmatrices("a", "b")
     kron_prod = pt.linalg.kron(a, b)
     det_output = pt.linalg.det(kron_prod)
@@ -837,7 +835,7 @@ def test_cholesky_diag_from_eye_mul(shape):
     )
 
 
-def test_cholesky_diag_from_diag():
+def test_cholesky_of_diag():
     x = pt.dvector("x")
     x_diag = pt.diag(x)
     x_cholesky = pt.linalg.cholesky(x_diag)
@@ -862,7 +860,7 @@
     )
 
 
-def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
+def test_cholesky_of_diag_not_applied():
     # Case 1 : y is not a diagonal matrix because of k = -1
     x = pt.tensor("x", shape=(7, 7))
     y = pt.eye(7, k=-1) * x
@@ -956,7 +954,7 @@ def test_slogdet_specialization():
         (CholeskySolve, pt.linalg.cho_solve, {}),
     ],
 )
-def test_scalar_solve_to_division_rewrite(
+def test_scalar_solve_to_division(
     op, fn, extra_kwargs, b_ndim, a_batch_shape, b_batch_shape
 ):
     def solve_op_in_graph(graph):