From 4126cacc3ec9ae15d1c067d266fb2294433eaded Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 10 Jan 2026 12:07:34 -0500 Subject: [PATCH 1/6] ensure code reusability --- pyhealth/datasets/base_dataset.py | 35 ++++++++++------------------- pyhealth/datasets/sample_dataset.py | 24 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index ec721e8c..59d66f43 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -235,30 +235,19 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: writer = BinaryWriter(cache_dir=str(output_dir), chunk_bytes="64MB") dataset = litdata.StreamingDataset(str(task_df)) - complete = 0 - with open(f"{output_dir}/schema.pkl", "rb") as f: - metadata = pickle.load(f) - - input_processors = metadata["input_processors"] - output_processors = metadata["output_processors"] - - write_index = 0 - for i in range(start_idx, end_idx): - transformed: Dict[str, Any] = {} - for key, value in pickle.loads(dataset[i]["sample"]).items(): - if key in input_processors: - transformed[key] = input_processors[key].process(value) - elif key in output_processors: - transformed[key] = output_processors[key].process(value) - else: - transformed[key] = value - writer.add_item(write_index, transformed) - write_index += 1 - complete += 1 + builder = SampleBuilder.load(f"{output_dir}/schema.pkl") - if complete >= BATCH_SIZE: - progress.put(complete) - complete = 0 + complete = 0 + write_index = 0 + for i in range(start_idx, end_idx): + transformed: Dict[str, Any] = builder.transform(pickle.loads(dataset[i])) + writer.add_item(write_index, transformed) + write_index += 1 + complete += 1 + + if complete >= BATCH_SIZE: + progress.put(complete) + complete = 0 if complete > 0: progress.put(complete) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 906b06b8..92fb2a20 100644 
--- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -220,6 +220,30 @@ def save(self, path: str) -> None: } with open(path, "wb") as f: pickle.dump(metadata, f) + + @staticmethod + def load(path: str) -> "SampleBuilder": + """Load a SampleBuilder from a pickled metadata file. + + Args: + path: Location of the pickled metadata file (commonly named `schema.pkl`). + + Returns: + A SampleBuilder instance with loaded metadata. + """ + with open(path, "rb") as f: + metadata = pickle.load(f) + + builder = SampleBuilder( + input_schema=metadata["input_schema"], + output_schema=metadata["output_schema"], + ) + builder._input_processors = metadata["input_processors"] + builder._output_processors = metadata["output_processors"] + builder._patient_to_index = metadata["patient_to_index"] + builder._record_to_index = metadata["record_to_index"] + builder._fitted = True + return builder class SampleDataset(litdata.StreamingDataset): From b9e27e2ebeea87d74599dd4fc2e526122730c2ee Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 10 Jan 2026 12:18:18 -0500 Subject: [PATCH 2/6] Fix incorrect unpickle --- pyhealth/datasets/base_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 59d66f43..91a1c95a 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -240,7 +240,7 @@ def _proc_transform_fn(args: tuple[int, Path, int, int, Path]) -> None: complete = 0 write_index = 0 for i in range(start_idx, end_idx): - transformed: Dict[str, Any] = builder.transform(pickle.loads(dataset[i])) + transformed: Dict[str, Any] = builder.transform(dataset[i]) writer.add_item(write_index, transformed) write_index += 1 complete += 1 From b1c6d31bb10886f4f6fc0eac35c10063b497f089 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 10 Jan 2026 12:26:11 -0500 Subject: [PATCH 3/6] add ignore processor --- pyhealth/processors/__init__.py | 1 + 
pyhealth/processors/ignore_processor.py | 26 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 pyhealth/processors/ignore_processor.py diff --git a/pyhealth/processors/__init__.py b/pyhealth/processors/__init__.py index 1b4e15ca..15512c2d 100644 --- a/pyhealth/processors/__init__.py +++ b/pyhealth/processors/__init__.py @@ -45,6 +45,7 @@ def get_processor(name: str): from .text_processor import TextProcessor from .timeseries_processor import TimeseriesProcessor from .audio_processor import AudioProcessor +from .ignore_processor import IgnoreProcessor # Expose public API __all__ = [ diff --git a/pyhealth/processors/ignore_processor.py b/pyhealth/processors/ignore_processor.py new file mode 100644 index 00000000..3d478340 --- /dev/null +++ b/pyhealth/processors/ignore_processor.py @@ -0,0 +1,26 @@ +from typing import Any, Dict, Iterable +from . import register_processor +from .base_processor import FeatureProcessor + + +@register_processor("ignore") +class IgnoreProcessor(FeatureProcessor): + """A special feature processor that marks a feature to be ignored during processing. + """ + + def __init__(self) -> None: + pass + + def process(self, value: Any) -> Any: + """This method is intentionally not implemented. + + Args: + value: Any raw field value. + + Raises: + NotImplementedError: Always raised to indicate this processor ignores the field. 
+ """ + raise NotImplementedError("IgnoreProcessor does not implement process method.") + + def __repr__(self) -> str: + return (f"IgnoreProcessor()") From 9e96849d34165c7a568c97975b48dfe731d32100 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 10 Jan 2026 12:26:32 -0500 Subject: [PATCH 4/6] Do not process ignore processor --- pyhealth/datasets/sample_dataset.py | 37 +++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 92fb2a20..728e4e8e 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -11,7 +11,7 @@ from litdata.utilities.train_test_split import deepcopy_dataset import copy -from ..processors import get_processor +from ..processors import get_processor, IgnoreProcessor from ..processors.base_processor import FeatureProcessor @@ -191,8 +191,14 @@ def transform(self, sample: dict[str, bytes]) -> Dict[str, Any]: transformed: Dict[str, Any] = {} for key, value in pickle.loads(sample["sample"]).items(): if key in self._input_processors: + # Skip ignored features + if isinstance(self._input_processors[key], IgnoreProcessor): + continue transformed[key] = self._input_processors[key].process(value) elif key in self._output_processors: + # Skip ignored features + if isinstance(self._output_processors[key], IgnoreProcessor): + continue transformed[key] = self._output_processors[key].process(value) else: transformed[key] = value @@ -220,14 +226,14 @@ def save(self, path: str) -> None: } with open(path, "wb") as f: pickle.dump(metadata, f) - + @staticmethod def load(path: str) -> "SampleBuilder": """Load a SampleBuilder from a pickled metadata file. Args: path: Location of the pickled metadata file (commonly named `schema.pkl`). - + Returns: A SampleBuilder instance with loaded metadata. 
""" @@ -300,10 +306,29 @@ def __init__( self.output_schema = metadata["output_schema"] self.input_processors = metadata["input_processors"] self.output_processors = metadata["output_processors"] + self._remove_ignored_processors() self.patient_to_index = metadata["patient_to_index"] self.record_to_index = metadata["record_to_index"] + def _remove_ignored_processors(self): + """Remove any processors that are IgnoreProcessor instances.""" + for key in ( + key + for key, proc in self.input_processors.items() + if isinstance(proc, IgnoreProcessor) + ): + del self.input_processors[key] + del self.input_schema[key] + + for key in ( + key + for key, proc in self.output_processors.items() + if isinstance(proc, IgnoreProcessor) + ): + del self.output_processors[key] + del self.output_schema[key] + def __str__(self) -> str: """Returns a string representation of the dataset. @@ -380,12 +405,12 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": new_dataset.reset() return new_dataset - + def close(self) -> None: """Cleans up any temporary directories used by the dataset.""" if self.input_dir.path is not None and Path(self.input_dir.path).exists(): shutil.rmtree(self.input_dir.path) - + # -------------------------------------------------------------- # Context manager support # -------------------------------------------------------------- @@ -450,6 +475,7 @@ def __init__( self.output_schema = builder.output_schema self.input_processors = builder.input_processors self.output_processors = builder.output_processors + self._remove_ignored_processors() self.patient_to_index = builder.patient_to_index self.record_to_index = builder.record_to_index @@ -506,6 +532,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: def close(self) -> None: pass # No temporary directories to clean up for in-memory dataset + def create_sample_dataset( samples: List[Dict[str, Any]], input_schema: Dict[str, Any], From 
85d9ac24e870c48b6712e81005c66033d49cea5c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 10 Jan 2026 12:29:39 -0500 Subject: [PATCH 5/6] Add test, fix bugs --- pyhealth/datasets/sample_dataset.py | 8 +-- tests/core/test_ignore_processor.py | 92 +++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 tests/core/test_ignore_processor.py diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 728e4e8e..30f4dd12 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -313,19 +313,19 @@ def __init__( def _remove_ignored_processors(self): """Remove any processors that are IgnoreProcessor instances.""" - for key in ( + for key in [ key for key, proc in self.input_processors.items() if isinstance(proc, IgnoreProcessor) - ): + ]: del self.input_processors[key] del self.input_schema[key] - for key in ( + for key in [ key for key, proc in self.output_processors.items() if isinstance(proc, IgnoreProcessor) - ): + ]: del self.output_processors[key] del self.output_schema[key] diff --git a/tests/core/test_ignore_processor.py b/tests/core/test_ignore_processor.py new file mode 100644 index 00000000..352f3559 --- /dev/null +++ b/tests/core/test_ignore_processor.py @@ -0,0 +1,92 @@ +import unittest +import shutil +import tempfile +from pathlib import Path +import pandas as pd +import dask.dataframe as dd + +from pyhealth.datasets.base_dataset import BaseDataset +from pyhealth.tasks.base_task import BaseTask +from pyhealth.processors.ignore_processor import IgnoreProcessor +from pyhealth.processors import RawProcessor + +class MockTask(BaseTask): + task_name = "test_task" + input_schema = { + "keep_field": "raw", + "ignore_field": "raw" + } + output_schema = {"label": "binary"} + + def __call__(self, patient): + return [{ + "keep_field": "keep_val", + "ignore_field": "ignore_val", + "label": 0 if patient.patient_id == "1" else 1, + "patient_id": 
patient.patient_id + }] + +class MockDataset(BaseDataset): + def __init__(self, root, **kwargs): + super().__init__(root=root, tables=[], **kwargs) + + def load_data(self): + return dd.from_pandas( + pd.DataFrame({ + "patient_id": ["1", "2"], + "event_type": ["visit", "visit"], + "timestamp": [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")], + }), + npartitions=1 + ) + +class TestIgnoreProcessor(unittest.TestCase): + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + self.root = self.tmp_dir + self.dataset = MockDataset(root=self.root) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + + def test_ignore_processor_with_set_task(self): + task = MockTask() + + # 1. Normal set_task + ds1 = self.dataset.set_task(task) + self.assertIn("ignore_field", ds1.input_schema) + + # Check data + # We need to access the first sample. + # Since SampleDataset is a StreamingDataset, we can index it or iterate. + sample1 = ds1[0] + self.assertIn("ignore_field", sample1) + self.assertEqual(sample1["ignore_field"], "ignore_val") + + # 2. set_task with ignore processor + # We MUST provide processors for ALL fields to avoid re-population logic in SampleBuilder + ds2 = self.dataset.set_task( + task, + input_processors={ + "keep_field": RawProcessor(), + "ignore_field": IgnoreProcessor() + } + ) + + # Expectation: "ignore_field" should be removed from input_schema of the dataset + # This is what the user asked for: "result should be the input_schema & input_processors does not exists" + + # Note: Depending on current implementation, this might fail. 
+ self.assertNotIn("ignore_field", ds2.input_schema) + self.assertNotIn("ignore_field", ds2.input_processors) + + sample2 = ds2[0] + # Expectation: "ignore_field" should NOT be in the sample data + self.assertNotIn("ignore_field", sample2) + + # 'keep_field' should still be there + self.assertIn("keep_field", sample2) + self.assertEqual(sample2["keep_field"], "keep_val") + +if __name__ == "__main__": + unittest.main() From 2eac66785b576681da7118184fcb42618258370e Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 11 Jan 2026 17:15:51 -0500 Subject: [PATCH 6/6] Update docstring --- docs/api/processors.rst | 2 ++ .../pyhealth.processors.IgnoreProcessor.rst | 9 +++++++++ pyhealth/processors/ignore_processor.py | 13 +++++++++++++ 3 files changed, 24 insertions(+) create mode 100644 docs/api/processors/pyhealth.processors.IgnoreProcessor.rst diff --git a/docs/api/processors.rst b/docs/api/processors.rst index dcdfb286..25de2fec 100644 --- a/docs/api/processors.rst +++ b/docs/api/processors.rst @@ -44,6 +44,7 @@ Available Processors - ``StageNetProcessor``: For StageNet model with lab measurements - ``StageNetTensorProcessor``: Tensor processing for StageNet - ``MultiHotProcessor``: For multi-hot encoding +- ``IgnoreProcessor``: A special feature processor that marks a feature to be ignored. 
Usage Examples -------------- @@ -460,6 +461,7 @@ API Reference processors/pyhealth.processors.TimeseriesProcessor processors/pyhealth.processors.TensorProcessor processors/pyhealth.processors.RawProcessor + processors/pyhealth.processors.IgnoreProcessor processors/pyhealth.processors.MultiHotProcessor processors/pyhealth.processors.StageNetProcessor processors/pyhealth.processors.StageNetTensorProcessor \ No newline at end of file diff --git a/docs/api/processors/pyhealth.processors.IgnoreProcessor.rst b/docs/api/processors/pyhealth.processors.IgnoreProcessor.rst new file mode 100644 index 00000000..aae6c322 --- /dev/null +++ b/docs/api/processors/pyhealth.processors.IgnoreProcessor.rst @@ -0,0 +1,9 @@ +pyhealth.processors.IgnoreProcessor +====================================== + +Processor to ignore a feature. + +.. autoclass:: pyhealth.processors.IgnoreProcessor + :members: + :undoc-members: + :show-inheritance: diff --git a/pyhealth/processors/ignore_processor.py b/pyhealth/processors/ignore_processor.py index 3d478340..2f8c35c3 100644 --- a/pyhealth/processors/ignore_processor.py +++ b/pyhealth/processors/ignore_processor.py @@ -6,6 +6,19 @@ @register_processor("ignore") class IgnoreProcessor(FeatureProcessor): """A special feature processor that marks a feature to be ignored during processing. + + This processor is useful when you want to remove a specific feature from the dataset + after the task function processing, but without modifying the task function itself. + + Example: + >>> from pyhealth.processors import IgnoreProcessor + >>> # Assume we have a task that outputs "feature1" and "feature2" + >>> # We want to remove "feature2" from the final dataset + >>> dataset.set_task(task, input_processors={ + ... "feature1": SequenceProcessor(code_to_index), + ... "feature2": IgnoreProcessor() + ... }) + >>> # Now samples in dataset will only contain "feature1" """ def __init__(self) -> None: