From 25afff65e4b0ce459186d4676b503ca4e59bdaa5 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 1 Jul 2025 13:00:16 -0500 Subject: [PATCH] deduplicate in _dag_to_compiled_func --- arraycontext/impl/pytato/compile.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index a4f80fb4..0b2cd715 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -287,8 +287,9 @@ def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, output_template): if isinstance(ary_or_dict_of_named_arrays, pt.Array): output_id = "_pt_out" - dict_of_named_arrays = pt.make_dict_of_named_arrays( - {output_id: ary_or_dict_of_named_arrays}) + dict_of_named_arrays = pt.transform.deduplicate( + pt.make_dict_of_named_arrays( + {output_id: ary_or_dict_of_named_arrays})) pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( self._dag_to_transformed_pytato_prg(dict_of_named_arrays, prg_id=self.f)) @@ -299,6 +300,8 @@ def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, output_axes=name_in_program_to_axes[output_id], output_name=output_id) elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): + ary_or_dict_of_named_arrays = pt.transform.deduplicate( + ary_or_dict_of_named_arrays) pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( self._dag_to_transformed_pytato_prg(ary_or_dict_of_named_arrays, prg_id=self.f))