38 changes: 18 additions & 20 deletions pytensor/graph/basic.py
@@ -553,31 +553,27 @@ def clone(self, **kwargs):
cp.tag = copy(self.tag)
return cp

def __lt__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __lt__", self.__class__.__name__
)

def __le__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __le__", self.__class__.__name__
)

def __gt__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __gt__", self.__class__.__name__
)

def __ge__(self, other):
raise NotImplementedError(
"Subclasses of Variable must provide __ge__", self.__class__.__name__
)

def get_parents(self):
if self.owner is not None:
return [self.owner]
return []

@property
def owner_op(self) -> Optional["Op"]:
if (apply := self.owner) is not None:
return apply.op # type: ignore[no-any-return]
else:
return None

@property
def owner_op_and_inputs(
self,
) -> tuple[Optional["Op"], *tuple["Variable", ...]]:
if (apply := self.owner) is not None:
return apply.op, *apply.inputs # type: ignore[has-type]
else:
return (None,)

def eval(
self,
inputs_to_values: dict[Union["Variable", str], Any] | None = None,
@@ -793,6 +789,8 @@ class Constant(AtomicVariable[_TypeType]):
"""

# __slots__ = ['data']
# Allow pattern matching on data field positionally
__match_args__ = ("data",)

def __init__(self, type: _TypeType, data: Any, name: str | None = None):
super().__init__(type, name=name)
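> **Review note:** the additions above work together: `owner_op_and_inputs` flattens a variable's defining `Apply` into a single `(op, *inputs)` tuple (or `(None,)` for root variables), and `__match_args__ = ("data",)` lets `Constant` be matched positionally on its `data` field. A minimal self-contained sketch of the resulting rewrite style, using stand-in classes rather than pytensor's real ones:

```python
class Constant:
    __match_args__ = ("data",)

    def __init__(self, data):
        self.data = data
        self.owner = None


class Mul: ...


class Apply:
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = inputs


class Var:
    def __init__(self, owner=None):
        self.owner = owner

    @property
    def owner_op_and_inputs(self):
        # Same shape as the new property: (op, *inputs), or (None,) for roots.
        if (apply := self.owner) is not None:
            return (apply.op, *apply.inputs)
        return (None,)


def is_mul_by_zero(var):
    # One pattern replaces an isinstance chain plus input unpacking.
    match var.owner_op_and_inputs:
        case (Mul(), Constant(0), _) | (Mul(), _, Constant(0)):
            return True
        case _:
            return False


zero = Constant(0)
y = Var()
assert is_mul_by_zero(Var(Apply(Mul(), [y, zero])))
assert not is_mul_by_zero(y)
```

> Returning `(None,)` rather than `None` keeps the property usable directly as a `match` subject: patterns over root variables simply fall through.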
2 changes: 1 addition & 1 deletion pytensor/link/jax/dispatch/elemwise.py
@@ -72,7 +72,7 @@ def careduce(x):
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x):
res = jnp.transpose(x, op.transposition)
res = jnp.transpose(x, op._transposition)

shape = list(res.shape[: len(op.shuffle)])

2 changes: 1 addition & 1 deletion pytensor/link/mlx/dispatch/elemwise.py
@@ -52,7 +52,7 @@ def dimshuffle(x):
isinstance(x, np.number) and not isinstance(x, np.ndarray)
):
x = mx.array(x)
res = mx.transpose(x, op.transposition)
res = mx.transpose(x, op._transposition)
shape = list(res.shape[: len(op.shuffle)])
for augm in op.augment:
shape.insert(augm, 1)
2 changes: 1 addition & 1 deletion pytensor/link/pytorch/dispatch/elemwise.py
@@ -54,7 +54,7 @@ def elemwise_fn(*inputs):
@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x):
res = torch.permute(x, op.transposition)
res = torch.permute(x, op._transposition)

shape = list(res.shape[: len(op.shuffle)])

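> **Review note:** the three backend dispatchers above (JAX, MLX, PyTorch) all implement `DimShuffle` with the same recipe: transpose the kept axes into output order with the dropped axes moved to the end, then reshape to squeeze the dropped axes and insert the broadcast (`"x"`) ones. A NumPy sketch of that recipe; the helper name and test values are illustrative:

```python
import numpy as np


def dimshuffle(x, input_ndim, new_order):
    # Kept input axes, in output order (everything except "x" markers).
    shuffle = [d for d in new_order if d != "x"]
    # Input axes absent from new_order; they must have length 1 to be dropped.
    drop = [d for d in range(input_ndim) if d not in new_order]
    # The `_transposition` step: kept axes first, dropped axes at the end.
    res = np.transpose(x, shuffle + drop)
    # Keep the shuffled extents, implicitly squeezing the dropped axes...
    shape = list(res.shape[: len(shuffle)])
    # ...and insert a length-1 axis wherever new_order says "x" (`augment`).
    for i, d in enumerate(new_order):
        if d == "x":
            shape.insert(i, 1)
    return res.reshape(shape)


x = np.arange(6).reshape(2, 1, 3)
# Drop axis 1, swap the remaining axes, and add a new broadcast axis.
assert dimshuffle(x, 3, (2, "x", 0)).shape == (3, 1, 2)
```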
97 changes: 47 additions & 50 deletions pytensor/tensor/_linalg/solve/rewriting.py
@@ -15,7 +15,6 @@
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
from pytensor.tensor.variable import TensorVariable

@@ -79,28 +78,26 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
# the root variable is the pre-DimShuffled input.
# Otherwise, `a` is considered the root variable.
# We also return whether the root `a` is transposed.
root_a = a
transposed = False
if a.owner is not None and isinstance(a.owner.op, DimShuffle):
if a.owner.op.is_left_expand_dims:
[a] = a.owner.inputs
elif is_matrix_transpose(a):
[a] = a.owner.inputs
transposed = True
return a, transposed
match a.owner_op_and_inputs:
case (DimShuffle(is_left_expand_dims=True), root_a): # type: ignore[misc]
transposed = False
case (DimShuffle(is_left_expanded_matrix_transpose=True), root_a): # type: ignore[misc]
transposed = True # type: ignore[unreachable]

return root_a, transposed

def find_solve_clients(var, assume_a):
clients = []
for cl, idx in fgraph.clients[var]:
if (
idx == 0
and isinstance(cl.op, Blockwise)
and isinstance(cl.op.core_op, Solve)
and (cl.op.core_op.assume_a == assume_a)
):
clients.append(cl)
elif isinstance(cl.op, DimShuffle) and cl.op.is_left_expand_dims:
# If it's a left expand_dims, recurse on the output
clients.extend(find_solve_clients(cl.outputs[0], assume_a))
match (idx, cl.op, *cl.outputs):
case (0, Blockwise(Solve(assume_a=assume_a_var)), *_) if (
assume_a_var == assume_a
):
clients.append(cl)
case (0, DimShuffle(is_left_expand_dims=True), cl_out):
clients.extend(find_solve_clients(cl_out, assume_a))
return clients

assume_a = node.op.core_op.assume_a
@@ -119,11 +116,11 @@ def find_solve_clients(var, assume_a):

# Find Solves using A.T
for cl, _ in fgraph.clients[A]:
if isinstance(cl.op, DimShuffle) and is_matrix_transpose(cl.out):
A_T = cl.out
A_solve_clients_and_transpose.extend(
(client, True) for client in find_solve_clients(A_T, assume_a)
)
match (cl.op, *cl.outputs):
case (DimShuffle(is_left_expanded_matrix_transpose=True), A_T):
A_solve_clients_and_transpose.extend(
(client, True) for client in find_solve_clients(A_T, assume_a)
)

if not eager and len(A_solve_clients_and_transpose) == 1:
# If there's a single use, don't do it... unless it's being broadcast in a Blockwise (or we're eager)
@@ -185,34 +182,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
changed = False
while True:
for inner_node in new_scan_fgraph.toposort():
if (
isinstance(inner_node.op, Blockwise)
and isinstance(inner_node.op.core_op, Solve)
and inner_node.op.core_op.assume_a in allowed_assume_a
):
A, _b = inner_node.inputs
if all(
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
for root_inp in graph_inputs([A])
match (inner_node.op, *inner_node.inputs):
case (Blockwise(Solve(assume_a=assume_a_var)), A, _b) if (
assume_a_var in allowed_assume_a
):
if new_scan_fgraph is scan_op.fgraph:
# Clone the first time to avoid mutating the original fgraph
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
non_sequences = {equiv[non_seq] for non_seq in non_sequences}
inner_node = equiv[inner_node] # type: ignore

replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph,
inner_node,
eager=True,
allowed_assume_a=allowed_assume_a,
)
assert isinstance(replace_dict, dict) and len(replace_dict) > 0, (
"Rewrite failed"
)
new_scan_fgraph.replace_all(replace_dict.items())
changed = True
break # Break to start over with a fresh toposort
if all(
(isinstance(root_inp, Constant) or (root_inp in non_sequences))
for root_inp in graph_inputs([A])
):
if new_scan_fgraph is scan_op.fgraph:
# Clone the first time to avoid mutating the original fgraph
new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv()
non_sequences = {
equiv[non_seq] for non_seq in non_sequences
}
inner_node = equiv[inner_node] # type: ignore

replace_dict = _split_decomp_and_solve_steps(
new_scan_fgraph,
inner_node,
eager=True,
allowed_assume_a=allowed_assume_a,
)
assert (
isinstance(replace_dict, dict) and len(replace_dict) > 0
), "Rewrite failed"
new_scan_fgraph.replace_all(replace_dict.items())
changed = True
break # Break to start over with a fresh toposort
else: # no_break
break # Nothing else changed

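> **Review note:** the `match` statements in this rewrite lean on two pattern-matching details worth calling out: class patterns nest, so `Blockwise(Solve(...))` replaces a pair of `isinstance` checks, and a bare name in a keyword pattern *captures* rather than compares, which is why the code binds `assume_a_var` and then filters it with an `if` guard. A self-contained sketch with stand-in classes, not pytensor's real ops:

```python
class Solve:
    __match_args__ = ("assume_a",)

    def __init__(self, assume_a):
        self.assume_a = assume_a


class Blockwise:
    __match_args__ = ("core_op",)

    def __init__(self, core_op):
        self.core_op = core_op


def is_allowed_solve(op, allowed_assume_a=("gen", "pos")):
    match op:
        # `assume_a=allowed_assume_a` would capture, not compare, so bind
        # a fresh name and restrict it with a guard instead.
        case Blockwise(Solve(assume_a=value)) if value in allowed_assume_a:
            return True
        case _:
            return False


assert is_allowed_solve(Blockwise(Solve("gen")))
assert not is_allowed_solve(Blockwise(Solve("tridiagonal")))
assert not is_allowed_solve(Solve("gen"))  # not wrapped in Blockwise
```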
2 changes: 2 additions & 0 deletions pytensor/tensor/blockwise.py
@@ -160,6 +160,8 @@ class Blockwise(COp):
"""

__props__ = ("core_op", "signature")
# Allow pattern matching on core_op positionally
__match_args__ = ("core_op",)

def __init__(
self,
56 changes: 35 additions & 21 deletions pytensor/tensor/elemwise.py
@@ -132,7 +132,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
raise TypeError(f"input_ndim must be an integer, got {type(int)}")

self.input_ndim = input_ndim
self.new_order = tuple(new_order)
self.new_order = new_order = tuple(new_order)
self._new_order = [(-1 if x == "x" else x) for x in self.new_order]

for i, j in enumerate(new_order):
@@ -153,28 +153,38 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
f"twice in the list of output dimensions: {new_order}"
)

# List of input dimensions to drop
drop = [i for i in range(input_ndim) if i not in new_order]

# This is the list of the original dimensions that we keep
self.shuffle = [x for x in new_order if x != "x"]
self.transposition = self.shuffle + drop
# List of dimensions of the output that are broadcastable and were not
# in the original input
self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop
# Tuple of the original dimensions that we keep
self.shuffle = tuple(x for x in new_order if x != "x")
# Tuple of input dimensions to drop
self.drop = drop = tuple(i for i in range(input_ndim) if i not in new_order)
# tuple of dimensions of the output that are broadcastable and were not in the original input
self.augment = augment = tuple(
sorted(i for i, x in enumerate(new_order) if x == "x")
)
n_augment = len(self.augment)

dims_are_shuffled = sorted(self.shuffle) != self.shuffle
# Used by perform
self._transposition = self.shuffle + drop

self.is_transpose = dims_are_shuffled and not augment and not drop
self.is_squeeze = drop and not dims_are_shuffled and not augment
self.is_expand_dims = augment and not dims_are_shuffled and not drop
self.is_left_expand_dims = self.is_expand_dims and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
# Classify the type of dimshuffle for rewrite purposes
dims_are_shuffled = tuple(sorted(self.shuffle)) != self.shuffle
self.is_squeeze = drop and not augment and not dims_are_shuffled
self.is_expand_dims = is_expand_dims = (
not drop and augment and not dims_are_shuffled
)
self.is_left_expand_dims = is_expand_dims and new_order[n_augment:] == tuple(
range(input_ndim)
)
self.is_right_expand_dims = is_expand_dims and new_order[:input_ndim] == tuple(
range(input_ndim)
)
self.is_transpose = not drop and not augment and dims_are_shuffled
self.is_left_expanded_matrix_transpose = is_left_expanded_matrix_transpose = (
dims_are_shuffled
and new_order[n_augment:]
== (*range(input_ndim - 2), input_ndim - 1, input_ndim - 2)
)
self.is_right_expand_dims = self.is_expand_dims and new_order[
:input_ndim
] == list(range(input_ndim))
self.is_matrix_transpose = not augment and is_left_expanded_matrix_transpose

def __setstate__(self, state):
self.__dict__.update(state)
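> **Review note:** the flag definitions above are compact but dense. A condensed re-derivation (a sketch, not the class itself) that can be checked against a few concrete `new_order` values:

```python
def classify(input_ndim, new_order):
    shuffle = tuple(d for d in new_order if d != "x")
    drop = tuple(d for d in range(input_ndim) if d not in new_order)
    augment = tuple(i for i, d in enumerate(new_order) if d == "x")
    shuffled = tuple(sorted(shuffle)) != shuffle
    expand = not drop and bool(augment) and not shuffled
    # Matrix transpose modulo leading expand_dims: swap the last two axes.
    left_expanded_mt = shuffled and new_order[len(augment):] == (
        *range(input_ndim - 2),
        input_ndim - 1,
        input_ndim - 2,
    )
    return {
        "is_squeeze": bool(drop) and not augment and not shuffled,
        "is_expand_dims": expand,
        "is_left_expand_dims": expand
        and new_order[len(augment):] == tuple(range(input_ndim)),
        "is_transpose": not drop and not augment and shuffled,
        "is_matrix_transpose": not augment and left_expanded_mt,
    }


assert classify(2, (1, 0))["is_matrix_transpose"]
assert classify(2, ("x", 0, 1))["is_left_expand_dims"]
assert classify(3, (0, 2))["is_squeeze"]
assert not classify(2, ("x", 1, 0))["is_matrix_transpose"]  # left-expanded only
```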
@@ -212,6 +222,8 @@ def make_node(self, inp):
return Apply(self, [input], [output])

def __str__(self):
if self.is_matrix_transpose:
return "MatrixTranspose"
if self.is_expand_dims:
if len(self.augment) == 1:
return f"ExpandDims{{axis={self.augment[0]}}}"
@@ -237,7 +249,7 @@ def perform(self, node, inp, out):
# )

# Put dropped axis at end
res = res.transpose(self.transposition)
res = res.transpose(self._transposition)

# Define new shape without dropped axis and including new ones
new_shape = list(res.shape[: len(self.shuffle)])
@@ -330,6 +342,8 @@ class Elemwise(OpenMPOp):
"""

__props__ = ("scalar_op", "inplace_pattern")
# Allow pattern matching on scalar_op positionally
__match_args__ = ("scalar_op",)

def __init__(
self, scalar_op, inplace_pattern=None, name=None, nfunc_spec=None, openmp=None
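> **Review note:** together with `Variable.owner_op` from the first file, `Elemwise.__match_args__ = ("scalar_op",)` lets graph-walking rewrites read declaratively. A hypothetical predicate in that style; every class here is a stand-in, not pytensor's real `Elemwise` or scalar ops:

```python
class Exp: ...
class Log: ...


class Elemwise:
    __match_args__ = ("scalar_op",)

    def __init__(self, scalar_op):
        self.scalar_op = scalar_op


class Var:
    def __init__(self, owner_op=None, inputs=()):
        self.owner_op = owner_op  # None for root variables, like the property
        self.inputs = inputs


def is_log_of_exp(var):
    # Matches log(exp(x)); an owner_op of None simply falls through.
    match var.owner_op:
        case Elemwise(Log()):
            [inner] = var.inputs
            match inner.owner_op:
                case Elemwise(Exp()):
                    return True
    return False


x = Var()
assert is_log_of_exp(Var(Elemwise(Log()), (Var(Elemwise(Exp()), (x,)),)))
assert not is_log_of_exp(x)
```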