From 1b3ee8a41d636ab34dd8224b0462cb3cae7488af Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Tue, 13 Jan 2026 22:02:05 -0500
Subject: [PATCH 1/3] write empty index file if no sample provided

---
 pyhealth/datasets/base_dataset.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 91a1c95a..8a37a4ea 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -115,6 +115,24 @@ def _csv_tsv_gz_path(path: str) -> str:
     raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
 
 
+def _litdata_empty_index(writer: BinaryWriter):
+    """
+    Create an empty index file for the LitData writer if it does not exist.
+    This prevents the program from hanging when merging empty datasets.
+
+    Args:
+        writer (BinaryWriter): The writer instance.
+    """
+    from litdata.streaming.writer import _INDEX_FILENAME
+
+    filepath = os.path.join(writer._cache_dir, f"{writer.rank}.{_INDEX_FILENAME}")
+    if not os.path.exists(filepath):
+        config = writer.get_config()
+        with open(filepath, "w") as out:
+            json.dump({"chunks": [], "config": config}, out, sort_keys=True)
+
+
 class _ProgressContext:
     def __init__(self, queue: multiprocessing.queues.Queue | None, total: int, **kwargs):
         """
@@ -196,6 +214,7 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P
             complete += 1
             progress.put(complete)
     writer.done()
+    _litdata_empty_index(writer)
     logger.info(f"Worker {worker_id} finished processing patients.")
 
 
@@ -252,6 +271,7 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None:
     if complete > 0:
         progress.put(complete)
     writer.done()
+    _litdata_empty_index(writer)
     logger.info(f"Worker {worker_id} finished processing samples.")
 
 
@@ -902,4 +922,4 @@ def _main_guard(self, func_name: str):
             f"{func_name} method accessed from a non-main process. This may lead to unexpected behavior.\n"
             + "Consider using a __name__ == '__main__' guard when using multiprocessing."
         )
-        exit(1)
\ No newline at end of file
+        exit(1)
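For reference, the helper in patch 1 writes the same JSON skeleton that LitData's BinaryWriter produces for a rank that wrote no chunks, so the subsequent merge sees one index per worker instead of waiting forever. A minimal standalone sketch of the file it creates, assuming _INDEX_FILENAME is "index.json" and using a throwaway cache directory and an empty config in place of writer._cache_dir and writer.get_config():

    # Illustrative sketch only: reproduce the empty per-worker index that
    # _litdata_empty_index writes for a worker with zero samples.
    import json
    import os
    import tempfile

    cache_dir = tempfile.mkdtemp()  # stands in for writer._cache_dir
    rank = 0                        # stands in for writer.rank
    index_name = "index.json"       # assumed value of _INDEX_FILENAME

    filepath = os.path.join(cache_dir, f"{rank}.{index_name}")
    if not os.path.exists(filepath):
        # An index with an empty chunk list is enough for the merge step
        # to account for this worker and complete.
        with open(filepath, "w") as out:
            json.dump({"chunks": [], "config": {}}, out, sort_keys=True)

    print(open(filepath).read())    # {"chunks": [], "config": {}}
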
""" from litdata.streaming.writer import _INDEX_FILENAME + files = os.listdir(cache_dir) + + # Return if the index already exists + if _INDEX_FILENAME in files: + return - filepath = os.path.join(writer._cache_dir, f"{writer.rank}.{_INDEX_FILENAME}") - if not os.path.exists(filepath): - config = writer.get_config() - with open(filepath, "w") as out: - json.dump({"chunks": [], "config": config}, out, sort_keys=True) + index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] + + # Return if there are no index files to merge + if len(index_files) == 0: + raise ValueError("There are zero samples in the dataset, please check the task and processors.") + + BinaryWriter(cache_dir=str(cache_dir), chunk_bytes="64MB").merge(num_workers=len(index_files)) class _ProgressContext: @@ -209,12 +215,13 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P patient_id = patient_id[0] # Extract string from single-element list patient = Patient(patient_id=patient_id, data_source=patient_df) for sample in task(patient): + if worker_id == 1: + continue # simulate empty task writer.add_item(write_index, {"sample": pickle.dumps(sample)}) write_index += 1 complete += 1 progress.put(complete) writer.done() - _litdata_empty_index(writer) logger.info(f"Worker {worker_id} finished processing patients.") @@ -271,7 +278,6 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: if complete > 0: progress.put(complete) writer.done() - _litdata_empty_index(writer) logger.info(f"Worker {worker_id} finished processing samples.") @@ -715,7 +721,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _task_transform_fn((0, task, patient_ids, global_event_df, output_dir)) - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) return # spwan is required for polars in multiprocessing, see https://docs.pola.rs/user-guide/misc/multiprocessing/#summary @@ -741,7 +747,7 @@ def _task_transform(self, task: BaseTask, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) logger.info(f"Task transformation completed and saved to {output_dir}") except Exception as e: @@ -765,7 +771,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> if num_workers == 1: logger.info("Single worker mode, processing sequentially") _proc_transform_fn((0, task_df, 0, num_samples, output_dir)) - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) return ctx = multiprocessing.get_context("spawn") @@ -791,7 +797,7 @@ def _proc_transform(self, task_df: Path, output_dir: Path, num_workers: int) -> while not queue.empty(): progress.update(queue.get()) result.get() # ensure exceptions are raised - BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB").merge(num_workers) + _litdata_merge(output_dir) logger.info(f"Processor transformation completed and saved to {output_dir}") except Exception as e: From 0a97d546be0b98042706dcc1068699830c68b022 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Tue, 13 Jan 2026 22:30:42 -0500 Subject: [PATCH 3/3] remove test code --- pyhealth/datasets/base_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git 
From 0a97d546be0b98042706dcc1068699830c68b022 Mon Sep 17 00:00:00 2001
From: Yongda Fan
Date: Tue, 13 Jan 2026 22:30:42 -0500
Subject: [PATCH 3/3] remove test code

---
 pyhealth/datasets/base_dataset.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py
index 00e27053..4c81743f 100644
--- a/pyhealth/datasets/base_dataset.py
+++ b/pyhealth/datasets/base_dataset.py
@@ -215,8 +215,6 @@ def _task_transform_fn(args: tuple[int, BaseTask, Iterable[str], pl.LazyFrame, P
         patient_id = patient_id[0]  # Extract string from single-element list
         patient = Patient(patient_id=patient_id, data_source=patient_df)
         for sample in task(patient):
-            if worker_id == 1:
-                continue  # simulate empty task
             writer.add_item(write_index, {"sample": pickle.dumps(sample)})
             write_index += 1
             complete += 1
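With the simulated-empty-worker lines removed, the zero-sample path is still reachable whenever a task yields no samples at all, and it now surfaces as a ValueError from _litdata_merge rather than a hang. A hypothetical regression check, assuming pytest and the module layout shown in the diffs:

    # Illustrative sketch only: an empty cache directory should raise, not hang.
    import tempfile
    from pathlib import Path

    import pytest

    from pyhealth.datasets.base_dataset import _litdata_merge

    def test_litdata_merge_empty_cache_raises():
        with tempfile.TemporaryDirectory() as d:
            with pytest.raises(ValueError, match="zero samples"):
                _litdata_merge(Path(d))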