From e7ca27e139230595521588c056a9851242bcd445 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 1 Jul 2025 14:05:51 -0500 Subject: [PATCH] avoid duplication in PytatoJAXArrayContext.freeze --- arraycontext/impl/pytato/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 1545a7ff..35bbdcc0 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -954,7 +954,9 @@ def _to_frozen( rec_keyed_map_array_container(_to_frozen, array), actx=None) - pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) + pt_dict_of_named_arrays = pt.transform.deduplicate( + pt.make_dict_of_named_arrays(key_to_pt_arrays)) + transformed_dag = self.transform_dag(pt_dict_of_named_arrays) pt_prg = pt.generate_jax(transformed_dag, jit=True) out_dict = pt_prg()