@@ -132,7 +132,7 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
132132 raise TypeError (f"input_ndim must be an integer, got { type (int )} " )
133133
134134 self .input_ndim = input_ndim
135- self .new_order = tuple (new_order )
135+ self .new_order = new_order = tuple (new_order )
136136 self ._new_order = [(- 1 if x == "x" else x ) for x in self .new_order ]
137137
138138 for i , j in enumerate (new_order ):
@@ -153,28 +153,38 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]):
153153 f"twice in the list of output dimensions: { new_order } "
154154 )
155155
156- # List of input dimensions to drop
157- drop = [i for i in range (input_ndim ) if i not in new_order ]
158-
159- # This is the list of the original dimensions that we keep
160- self .shuffle = [x for x in new_order if x != "x" ]
161- self .transposition = self .shuffle + drop
162- # List of dimensions of the output that are broadcastable and were not
163- # in the original input
164- self .augment = augment = sorted (i for i , x in enumerate (new_order ) if x == "x" )
165- self .drop = drop
156+ # Tuple of the original dimensions that we keep
157+ self .shuffle = tuple (x for x in new_order if x != "x" )
158+ # Tuple of input dimensions to drop
159+ self .drop = drop = tuple (i for i in range (input_ndim ) if i not in new_order )
160+ # tuple of dimensions of the output that are broadcastable and were not in the original input
161+ self .augment = augment = tuple (
162+ sorted (i for i , x in enumerate (new_order ) if x == "x" )
163+ )
164+ n_augment = len (self .augment )
166165
167- dims_are_shuffled = sorted (self .shuffle ) != self .shuffle
166+ # Used by perform
167+ self ._transposition = self .shuffle + drop
168168
169- self .is_transpose = dims_are_shuffled and not augment and not drop
170- self .is_squeeze = drop and not dims_are_shuffled and not augment
171- self .is_expand_dims = augment and not dims_are_shuffled and not drop
172- self .is_left_expand_dims = self .is_expand_dims and (
173- input_ndim == 0 or new_order [- input_ndim :] == list (range (input_ndim ))
169+ # Classify the type of dimshuffle for rewrite purposes
170+ dims_are_shuffled = tuple (sorted (self .shuffle )) != self .shuffle
171+ self .is_squeeze = drop and not augment and not dims_are_shuffled
172+ self .is_expand_dims = is_expand_dims = (
173+ not drop and augment and not dims_are_shuffled
174+ )
175+ self .is_left_expand_dims = is_expand_dims and new_order [n_augment :] == tuple (
176+ range (input_ndim )
177+ )
178+ self .is_right_expand_dims = is_expand_dims and new_order [:input_ndim ] == tuple (
179+ range (input_ndim )
180+ )
181+ self .is_transpose = not drop and not augment and dims_are_shuffled
182+ self .is_left_expanded_matrix_transpose = is_left_expanded_matrix_transpose = (
183+ dims_are_shuffled
184+ and new_order [n_augment :]
185+ == (* range (input_ndim - 2 ), input_ndim - 1 , input_ndim - 2 )
174186 )
175- self .is_right_expand_dims = self .is_expand_dims and new_order [
176- :input_ndim
177- ] == list (range (input_ndim ))
187+ self .is_matrix_transpose = not augment and is_left_expanded_matrix_transpose
178188
179189 def __setstate__ (self , state ):
180190 self .__dict__ .update (state )
@@ -212,6 +222,8 @@ def make_node(self, inp):
212222 return Apply (self , [input ], [output ])
213223
214224 def __str__ (self ):
225+ if self .is_matrix_transpose :
226+ return "MatrixTranspose"
215227 if self .is_expand_dims :
216228 if len (self .augment ) == 1 :
217229 return f"ExpandDims{{axis={ self .augment [0 ]} }}"
@@ -237,7 +249,7 @@ def perform(self, node, inp, out):
237249 # )
238250
239251 # Put dropped axis at end
240- res = res .transpose (self .transposition )
252+ res = res .transpose (self ._transposition )
241253
242254 # Define new shape without dropped axis and including new ones
243255 new_shape = list (res .shape [: len (self .shuffle )])
@@ -330,6 +342,8 @@ class Elemwise(OpenMPOp):
330342 """
331343
332344 __props__ = ("scalar_op" , "inplace_pattern" )
345+ # Allow pattern matching on scalar_op positionally
346+ __match_args__ = ("scalar_op" ,)
333347
334348 def __init__ (
335349 self , scalar_op , inplace_pattern = None , name = None , nfunc_spec = None , openmp = None
0 commit comments