From 49ea92d2130bebd0aa0750125ba9c21d375cd578 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 13 Jan 2026 21:21:49 -0500 Subject: [PATCH 1/4] skip proc transformation if it already exists --- pyhealth/datasets/base_dataset.py | 92 ++++++++++++++++++++----------- 1 file changed, 60 insertions(+), 32 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 91a1c95a..62d10a09 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -842,51 +842,79 @@ def set_task( ) if cache_dir is None: - cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params).hex}" cache_dir.mkdir(parents=True, exist_ok=True) else: # Ensure the explicitly provided cache_dir exists cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True) + proc_params = json.dumps( + { + "input_schema": task.input_schema, + "output_schema": task.output_schema, + "input_processors": ( + { + f"{k}_{v.__class__.__name__}": vars(v) + for k, v in input_processors.items() + } + if input_processors + else None + ), + "output_processors": ( + { + f"{k}_{v.__class__.__name__}": vars(v) + for k, v in output_processors.items() + } + if output_processors + else None + ), + } + ) + task_df_path = Path(cache_dir) / "task_df.ld" - samples_path = Path(cache_dir) / f"samples_{uuid.uuid4()}.ld" + samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params).hex}.ld" task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) + + if not (samples_path / "index.json").exists(): + # Check if index.json exists to verify cache integrity, this + # is the standard file for litdata.StreamingDataset + if not (task_df_path / "index.json").exists(): + self._task_transform( + task, + task_df_path, + num_workers, + ) + else: + 
logger.info(f"Found cached task dataframe at {task_df_path}, skipping task transformation.") - # Check if index.json exists to verify cache integrity, this - # is the standard file for litdata.StreamingDataset - if not (task_df_path / "index.json").exists(): - self._task_transform( - task, + # Build processors and fit on the dataset + logger.info(f"Fitting processors on the dataset...") + dataset = litdata.StreamingDataset( + str(task_df_path), + transform=lambda x: pickle.loads(x["sample"]), + ) + builder = SampleBuilder( + input_schema=task.input_schema, # type: ignore + output_schema=task.output_schema, # type: ignore + input_processors=input_processors, + output_processors=output_processors, + ) + builder.fit(dataset) + builder.save(str(samples_path / "schema.pkl")) + + # Apply processors and save final samples to cache_dir + logger.info(f"Processing samples and saving to {samples_path}...") + self._proc_transform( task_df_path, + samples_path, num_workers, ) - - # Build processors and fit on the dataset - logger.info(f"Fitting processors on the dataset...") - dataset = litdata.StreamingDataset( - str(task_df_path), - transform=lambda x: pickle.loads(x["sample"]), - ) - builder = SampleBuilder( - input_schema=task.input_schema, # type: ignore - output_schema=task.output_schema, # type: ignore - input_processors=input_processors, - output_processors=output_processors, - ) - builder.fit(dataset) - builder.save(str(samples_path / "schema.pkl")) - - # Apply processors and save final samples to cache_dir - logger.info(f"Processing samples and saving to {samples_path}...") - self._proc_transform( - task_df_path, - samples_path, - num_workers, - ) - logger.info(f"Cached processed samples to {samples_path}") + logger.info(f"Cached processed samples to {samples_path}") + else: + logger.info(f"Found cached processed samples at {samples_path}, skipping processing.") return SampleDataset( path=str(samples_path), @@ -902,4 +930,4 @@ def _main_guard(self, func_name: str): 
f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n" + "Consider use __name__ == '__main__' guard when using multiprocessing." ) - exit(1) \ No newline at end of file + exit(1) From 7c1ec3abd8750f6d9aca94c2df3b3ae73f8cd65a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 13 Jan 2026 21:30:30 -0500 Subject: [PATCH 2/4] Fix sort_key --- pyhealth/datasets/base_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 62d10a09..967c198a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -869,7 +869,9 @@ def set_task( if output_processors else None ), - } + }, + sort_keys=True, + default=str ) task_df_path = Path(cache_dir) / "task_df.ld" From 616e87bc513efd688096755406eeaa8df45fe8f5 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 13 Jan 2026 21:46:51 -0500 Subject: [PATCH 3/4] Fix test --- tests/core/test_caching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index fa832c31..c6a3536f 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -167,7 +167,7 @@ def test_default_cache_dir_is_used(self): default=str ) - task_cache = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" + task_cache = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params).hex}" sample_dataset = self.dataset.set_task(self.task) self.assertTrue(task_cache.exists()) @@ -211,8 +211,8 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): default=str ) - task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1)}" - task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2)}" + 
task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1).hex}" + task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2).hex}" self.assertTrue(task_cache1.exists()) self.assertTrue(task_cache2.exists()) From b6a89c8f0885f3f869d6c082e4f73d3e04a9db7a Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 13 Jan 2026 21:49:38 -0500 Subject: [PATCH 4/4] remove hex for backward compatibility --- pyhealth/datasets/base_dataset.py | 4 ++-- tests/core/test_caching.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 967c198a..f40676d8 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -842,7 +842,7 @@ def set_task( ) if cache_dir is None: - cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params).hex}" + cache_dir = self.cache_dir / "tasks" / f"{task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" cache_dir.mkdir(parents=True, exist_ok=True) else: # Ensure the explicitly provided cache_dir exists @@ -875,7 +875,7 @@ def set_task( ) task_df_path = Path(cache_dir) / "task_df.ld" - samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params).hex}.ld" + samples_path = Path(cache_dir) / f"samples_{uuid.uuid5(uuid.NAMESPACE_DNS, proc_params)}.ld" task_df_path.mkdir(parents=True, exist_ok=True) samples_path.mkdir(parents=True, exist_ok=True) diff --git a/tests/core/test_caching.py b/tests/core/test_caching.py index c6a3536f..fa832c31 100644 --- a/tests/core/test_caching.py +++ b/tests/core/test_caching.py @@ -167,7 +167,7 @@ def test_default_cache_dir_is_used(self): default=str ) - task_cache = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params).hex}" + task_cache = self.dataset.cache_dir / 
"tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params)}" sample_dataset = self.dataset.set_task(self.task) self.assertTrue(task_cache.exists()) @@ -211,8 +211,8 @@ def test_tasks_with_diff_param_values_get_diff_caches(self): default=str ) - task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1).hex}" - task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2).hex}" + task_cache1 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params1)}" + task_cache2 = self.dataset.cache_dir / "tasks" / f"{self.task.task_name}_{uuid.uuid5(uuid.NAMESPACE_DNS, task_params2)}" self.assertTrue(task_cache1.exists()) self.assertTrue(task_cache2.exists())