diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index a4f80fb4..0b2cd715 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -287,8 +287,9 @@ def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, output_template): if isinstance(ary_or_dict_of_named_arrays, pt.Array): output_id = "_pt_out" - dict_of_named_arrays = pt.make_dict_of_named_arrays( - {output_id: ary_or_dict_of_named_arrays}) + dict_of_named_arrays = pt.transform.deduplicate( + pt.make_dict_of_named_arrays( + {output_id: ary_or_dict_of_named_arrays})) pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( self._dag_to_transformed_pytato_prg(dict_of_named_arrays, prg_id=self.f)) @@ -299,6 +300,8 @@ def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays, output_axes=name_in_program_to_axes[output_id], output_name=output_id) elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays): + ary_or_dict_of_named_arrays = pt.transform.deduplicate( + ary_or_dict_of_named_arrays) pytato_program, name_in_program_to_tags, name_in_program_to_axes = ( self._dag_to_transformed_pytato_prg(ary_or_dict_of_named_arrays, prg_id=self.f))