@@ -92,16 +92,13 @@ def get_root_A(a: TensorVariable) -> tuple[TensorVariable, bool]:
9292 def find_solve_clients (var , assume_a ):
9393 clients = []
9494 for cl , idx in fgraph .clients [var ]:
95- if (
96- idx == 0
97- and isinstance (cl .op , Blockwise )
98- and isinstance (cl .op .core_op , Solve )
99- and (cl .op .core_op .assume_a == assume_a )
100- ):
101- clients .append (cl )
102- elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
103- # If it's a left expand_dims, recurse on the output
104- clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
95+ match (idx , cl .op , * cl .outputs ):
96+ case (0 , Blockwise (Solve (assume_a = assume_a_var )), * _) if (
97+ assume_a_var == assume_a
98+ ):
99+ clients .append (cl )
100+ case (0 , DimShuffle (is_left_expand_dims = True ), cl_out ):
101+ clients .extend (find_solve_clients (cl_out , assume_a ))
105102 return clients
106103
107104 assume_a = node .op .core_op .assume_a
@@ -186,34 +183,34 @@ def _scan_split_non_sequence_decomposition_and_solve(
186183 changed = False
187184 while True :
188185 for inner_node in new_scan_fgraph .toposort ():
189- if (
190- isinstance (inner_node .op , Blockwise )
191- and isinstance (inner_node .op .core_op , Solve )
192- and inner_node .op .core_op .assume_a in allowed_assume_a
193- ):
194- A , _b = inner_node .inputs
195- if all (
196- (isinstance (root_inp , Constant ) or (root_inp in non_sequences ))
197- for root_inp in graph_inputs ([A ])
186+ match (inner_node .op , * inner_node .inputs ):
187+ case (Blockwise (Solve (assume_a = assume_a_var )), A , _b ) if (
188+ assume_a_var in allowed_assume_a
198189 ):
199- if new_scan_fgraph is scan_op .fgraph :
200- # Clone the first time to avoid mutating the original fgraph
201- new_scan_fgraph , equiv = new_scan_fgraph .clone_get_equiv ()
202- non_sequences = {equiv [non_seq ] for non_seq in non_sequences }
203- inner_node = equiv [inner_node ] # type: ignore
204-
205- replace_dict = _split_decomp_and_solve_steps (
206- new_scan_fgraph ,
207- inner_node ,
208- eager = True ,
209- allowed_assume_a = allowed_assume_a ,
210- )
211- assert isinstance (replace_dict , dict ) and len (replace_dict ) > 0 , (
212- "Rewrite failed"
213- )
214- new_scan_fgraph .replace_all (replace_dict .items ())
215- changed = True
216- break # Break to start over with a fresh toposort
190+ if all (
191+ (isinstance (root_inp , Constant ) or (root_inp in non_sequences ))
192+ for root_inp in graph_inputs ([A ])
193+ ):
194+ if new_scan_fgraph is scan_op .fgraph :
195+ # Clone the first time to avoid mutating the original fgraph
196+ new_scan_fgraph , equiv = new_scan_fgraph .clone_get_equiv ()
197+ non_sequences = {
198+ equiv [non_seq ] for non_seq in non_sequences
199+ }
200+ inner_node = equiv [inner_node ] # type: ignore
201+
202+ replace_dict = _split_decomp_and_solve_steps (
203+ new_scan_fgraph ,
204+ inner_node ,
205+ eager = True ,
206+ allowed_assume_a = allowed_assume_a ,
207+ )
208+ assert (
209+ isinstance (replace_dict , dict ) and len (replace_dict ) > 0
210+ ), "Rewrite failed"
211+ new_scan_fgraph .replace_all (replace_dict .items ())
212+ changed = True
213+ break # Break to start over with a fresh toposort
217214 else : # no_break
218215 break # Nothing else changed
219216
0 commit comments