Skip to content

Commit 8786e6e

Browse files
committed
Facilitate graph pattern matching
1 parent aefca9b commit 8786e6e

File tree

7 files changed

+53
-25
lines changed

7 files changed

+53
-25
lines changed

pytensor/graph/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,16 @@ def get_parents(self):
558558
return [self.owner]
559559
return []
560560

561+
@property
562+
def owner_op(self) -> Optional["Op"]:
563+
return apply.op if (apply := self.owner) is not None else None
564+
565+
@property
566+
def owner_op_and_inputs(self) -> tuple[Optional["Op"], "Variable", ...]:
567+
if (apply := self.owner) is not None:
568+
return (apply.op, *apply.inputs)
569+
return (None,)
570+
561571
def eval(
562572
self,
563573
inputs_to_values: dict[Union["Variable", str], Any] | None = None,
@@ -773,6 +783,8 @@ class Constant(AtomicVariable[_TypeType]):
773783
"""
774784

775785
# __slots__ = ['data']
786+
# Allow pattern matching on data field positionally
787+
__match_args__ = ("data",)
776788

777789
def __init__(self, type: _TypeType, data: Any, name: str | None = None):
778790
super().__init__(type, name=name)

pytensor/link/jax/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def careduce(x):
7272
@jax_funcify.register(DimShuffle)
7373
def jax_funcify_DimShuffle(op, **kwargs):
7474
def dimshuffle(x):
75-
res = jnp.transpose(x, op.transposition)
75+
res = jnp.transpose(x, op._transposition)
7676

7777
shape = list(res.shape[: len(op.shuffle)])
7878

pytensor/link/mlx/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def dimshuffle(x):
5252
isinstance(x, np.number) and not isinstance(x, np.ndarray)
5353
):
5454
x = mx.array(x)
55-
res = mx.transpose(x, op.transposition)
55+
res = mx.transpose(x, op._transposition)
5656
shape = list(res.shape[: len(op.shuffle)])
5757
for augm in op.augment:
5858
shape.insert(augm, 1)

pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def elemwise_fn(*inputs):
5454
@pytorch_funcify.register(DimShuffle)
5555
def pytorch_funcify_DimShuffle(op, **kwargs):
5656
def dimshuffle(x):
57-
res = torch.permute(x, op.transposition)
57+
res = torch.permute(x, op._transposition)
5858

5959
shape = list(res.shape[: len(op.shuffle)])
6060

pytensor/tensor/blockwise.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ class Blockwise(COp):
160160
"""
161161

162162
__props__ = ("core_op", "signature")
163+
# Allow pattern matching on core_op positionally
164+
__match_args__ = ("core_op",)
163165

164166
def __init__(
165167
self,

pytensor/tensor/elemwise.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
132132
raise TypeError(f"input_ndim must be an integer, got {type(int)}")
133133

134134
self.input_ndim = input_ndim
135-
self.new_order = tuple(new_order)
135+
self.new_order = new_order = tuple(new_order)
136136
self._new_order = [(-1 if x == "x" else x) for x in self.new_order]
137137

138138
for i, j in enumerate(new_order):
@@ -153,28 +153,38 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
153153
f"twice in the list of output dimensions: {new_order}"
154154
)
155155

156-
# List of input dimensions to drop
157-
drop = [i for i in range(input_ndim) if i not in new_order]
158-
159-
# This is the list of the original dimensions that we keep
160-
self.shuffle = [x for x in new_order if x != "x"]
161-
self.transposition = self.shuffle + drop
162-
# List of dimensions of the output that are broadcastable and were not
163-
# in the original input
164-
self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
165-
self.drop = drop
156+
# Tuple of the original dimensions that we keep
157+
self.shuffle = tuple(x for x in new_order if x != "x")
158+
# Tuple of input dimensions to drop
159+
self.drop = drop = tuple(i for i in range(input_ndim) if i not in new_order)
160+
# tuple of dimensions of the output that are broadcastable and were not in the original input
161+
self.augment = augment = tuple(
162+
sorted(i for i, x in enumerate(new_order) if x == "x")
163+
)
164+
n_augment = len(self.augment)
166165

167-
dims_are_shuffled = sorted(self.shuffle) != self.shuffle
166+
# Used by perform
167+
self._transposition = self.shuffle + drop
168168

169-
self.is_transpose = dims_are_shuffled and not augment and not drop
170-
self.is_squeeze = drop and not dims_are_shuffled and not augment
171-
self.is_expand_dims = augment and not dims_are_shuffled and not drop
172-
self.is_left_expand_dims = self.is_expand_dims and (
173-
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
169+
# Classify the type of dimshuffle for rewrite purposes
170+
dims_are_shuffled = tuple(sorted(self.shuffle)) != self.shuffle
171+
self.is_squeeze = drop and not augment and not dims_are_shuffled
172+
self.is_expand_dims = is_expand_dims = (
173+
not drop and augment and not dims_are_shuffled
174+
)
175+
self.is_left_expand_dims = is_expand_dims and new_order[n_augment:] == tuple(
176+
range(input_ndim)
177+
)
178+
self.is_right_expand_dims = is_expand_dims and new_order[:input_ndim] == tuple(
179+
range(input_ndim)
180+
)
181+
self.is_transpose = not drop and not augment and dims_are_shuffled
182+
self.is_left_expanded_matrix_transpose = is_left_expanded_matrix_transpose = (
183+
dims_are_shuffled
184+
and new_order[n_augment:]
185+
== (*range(input_ndim - 2), input_ndim - 1, input_ndim - 2)
174186
)
175-
self.is_right_expand_dims = self.is_expand_dims and new_order[
176-
:input_ndim
177-
] == list(range(input_ndim))
187+
self.is_matrix_transpose = not augment and is_left_expanded_matrix_transpose
178188

179189
def __setstate__(self, state):
180190
self.__dict__.update(state)
@@ -212,6 +222,8 @@ def make_node(self, inp):
212222
return Apply(self, [input], [output])
213223

214224
def __str__(self):
225+
if self.is_matrix_transpose:
226+
return "MatrixTranspose"
215227
if self.is_expand_dims:
216228
if len(self.augment) == 1:
217229
return f"ExpandDims{{axis={self.augment[0]}}}"
@@ -237,7 +249,7 @@ def perform(self, node, inp, out):
237249
# )
238250

239251
# Put dropped axis at end
240-
res = res.transpose(self.transposition)
252+
res = res.transpose(self._transposition)
241253

242254
# Define new shape without dropped axis and including new ones
243255
new_shape = list(res.shape[: len(self.shuffle)])
@@ -330,6 +342,8 @@ class Elemwise(OpenMPOp):
330342
"""
331343

332344
__props__ = ("scalar_op", "inplace_pattern")
345+
# Allow pattern matching on scalar_op positionally
346+
__match_args__ = ("scalar_op",)
333347

334348
def __init__(
335349
self, scalar_op, inplace_pattern=None, name=None, nfunc_spec=None, openmp=None

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def local_subtensor_of_transpose(fgraph, node):
521521
if not ds_op.is_transpose:
522522
return None
523523

524-
transposition = ds_op.transposition
524+
transposition = ds_op._transposition
525525
[x] = ds.owner.inputs
526526

527527
idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

0 commit comments

Comments
 (0)