diff --git a/dataaug_platform/__init__.py b/dataaug_platform/__init__.py
index 52c1281..129770c 100644
--- a/dataaug_platform/__init__.py
+++ b/dataaug_platform/__init__.py
@@ -7,18 +7,24 @@
     local_aug,
     global_aug,
 )
-from .pipeline import Pipeline
+from .dataset import SparkIterableDataset
 from .ingestion import load_hdf5_group, hdf5_to_rdd, read_hdf5_metadata, write_trajectories_to_hdf5
+from .pipeline import Pipeline
 
 __all__ = [
+    # augmentation
     "Augmentation",
     "local_aug",
     "global_aug",
-    "Pipeline",
+    # dataset
+    "SparkIterableDataset",
+    # ingestion
     "load_hdf5_group",
     "hdf5_to_rdd",
     "read_hdf5_metadata",
     "write_trajectories_to_hdf5",
+    # pipeline
+    "Pipeline",
 ]
 
 # Optional imports - only available if mimicgen dependencies are installed
diff --git a/dataaug_platform/dataset.py b/dataaug_platform/dataset.py
new file mode 100644
index 0000000..c5ccde5
--- /dev/null
+++ b/dataaug_platform/dataset.py
@@ -0,0 +1,52 @@
+import torch
+from torch.utils.data import IterableDataset
+
+from .pipeline import Pipeline
+
+class SparkIterableDataset(IterableDataset):
+    def __init__(
+        self,
+        spark_pipeline: Pipeline,
+        to_tensor: bool = True,
+        infinite: bool = False,
+    ):
+        super().__init__()
+        self.spark_pipeline = spark_pipeline
+        self.to_tensor = to_tensor
+        self.infinite = infinite
+
+    def _convert(self, item):
+        """Convert Spark output item to torch.Tensor/dict-of-tensors if needed."""
+        if not self.to_tensor:
+            return item
+
+        if isinstance(item, torch.Tensor):
+            return item
+
+        if isinstance(item, dict):
+            out = {}
+            for k, v in item.items():
+                if isinstance(v, torch.Tensor):
+                    out[k] = v
+                else:
+                    out[k] = torch.tensor(v, dtype=torch.float32)
+            return out
+
+        if isinstance(item, (list, tuple)):
+            return torch.tensor(item, dtype=torch.float32)
+
+        if isinstance(item, (int, float)):
+            return torch.tensor([item], dtype=torch.float32)
+
+        return item
+
+    def __iter__(self):
+        if self.infinite:
+            while True:
+                rdd = self.spark_pipeline.run()
+                for item in rdd.toLocalIterator():
+                    yield self._convert(item)
+        else:
+            rdd = self.spark_pipeline.run()
+            for item in rdd.toLocalIterator():
+                yield self._convert(item)
\ No newline at end of file
diff --git a/dataaug_platform/pipeline.py b/dataaug_platform/pipeline.py
index bb7f54a..cc7a9f9 100644
--- a/dataaug_platform/pipeline.py
+++ b/dataaug_platform/pipeline.py
@@ -1,30 +1,82 @@
 from abc import ABC, abstractmethod
+from typing import Optional
+
 from pyspark.sql import SparkSession
-from .augmentations.base_augmentation import Augmentation
+from .augmentations.base_augmentation import Augmentation
 
 
 class Pipeline:
     """Manages a sequence of augmentations."""
 
-    def __init__(self, spark=None):
+    def __init__(self, spark: Optional[SparkSession] = None):
         self.spark = (
             spark or SparkSession.builder.appName("TrajectoryPipeline").getOrCreate()
         )
         self.augmentations = []
+        self._base_rdd = None  # stored input RDD (for set_data-style use)
+
+    def __enter__(self):
+        return self
 
-    def add(self, aug: Augmentation):
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.spark is not None:
+            self.spark.stop()
+
+    def add(self, aug: "Augmentation"):
         """Add an augmentation to the pipeline."""
         self.augmentations.append(aug)
         return self  # enable chaining
 
-    def run(self, data):
-        """Run all augmentations sequentially."""
+    def set_data(self, data, cache=False):
+        """
+        Set the base data for the pipeline.
+
+        `data` can be a Python list or an existing RDD.
+        If `cache=True`, the RDD will be cached in memory.
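+
+        Example (illustrative sketch; `sample_data` is any list of trajectory
+        dicts, as in examples/explore_pytorch.py):
+
+            pipeline.set_data(sample_data, cache=True)
+            rdd = pipeline.run()  # falls back to the stored RDD when no data is passed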
+ """ sc = self.spark.sparkContext - if not hasattr(data, "context"): # convert list ? RDD if needed + + # list / iterable -> parallelize, RDD -> keep + if hasattr(data, "context"): # looks like an RDD + rdd = data + else: rdd = sc.parallelize(data) + + if cache: + rdd = rdd.cache() + + self._base_rdd = rdd + return self + + def run(self, data=None, use_stored_if_no_data=True): + """ + Run all augmentations sequentially. + + - If `data` is provided: + behaves like the old version: converts list -> RDD if needed, does NOT + modify the stored base data. + - If `data` is None and `use_stored_if_no_data` is True: + uses the RDD set via `set_data(...)`. + """ + sc = self.spark.sparkContext + + if data is not None: + # old behavior: convert to RDD if needed + if hasattr(data, "context"): # RDD + rdd = data + else: # list / iterable + rdd = sc.parallelize(data) else: - rdd = data + if not use_stored_if_no_data: + raise ValueError("No data passed to run() and use_stored_if_no_data=False.") + if self._base_rdd is None: + raise ValueError( + "No data passed to run() and no data set via set_data(). " + "Call run(data=...) or set_data(...) first." + ) + rdd = self._base_rdd + # apply augmentations for aug in self.augmentations: rdd = aug._apply_rdd(rdd) diff --git a/examples/explore_pytorch.py b/examples/explore_pytorch.py new file mode 100644 index 0000000..889067c --- /dev/null +++ b/examples/explore_pytorch.py @@ -0,0 +1,137 @@ +""" +Example demonstrating how to use user-defined augmentation classes. + +This shows how to define augmentation classes with @local_aug and @global_aug decorators. +""" + +from torch.utils.data import DataLoader + +from pyspark.sql import SparkSession +from dataaug_platform import Augmentation, local_aug, global_aug, Pipeline, SparkIterableDataset + + +# ============================================================================ +# Example 1: User-defined local augmentation class +# ============================================================================ + +class AddOffsetAugmentation(Augmentation): + """Add a constant offset to all numeric values.""" + + def __init__(self, offset=1.0): + self.offset = offset + + @local_aug + def apply(self, traj): + """Process one trajectory at a time.""" + import numpy as np + + modified = traj.copy() + for key, value in modified.items(): + if isinstance(value, np.ndarray): + modified[key] = value + self.offset + return modified + + +# ============================================================================ +# Example 2: User-defined global augmentation class +# ============================================================================ + +class AverageTrajectoriesAugmentation(Augmentation): + """Create a new trajectory by averaging all trajectories.""" + + def __init__(self, times=1, keep_original=True): + """ + Initialize the augmentation. + + Args: + times: Number of times to run this augmentation (default: 1). + Each run processes the whole dataset and produces new trajectories. + Multiple runs are parallelized using Spark. + keep_original: Whether to keep original trajectories in output (default: True). + If False, output contains only augmented trajectories. 
+ """ + super().__init__(times=times, keep_original=keep_original) + + @global_aug + def apply(self, trajs): + """Process all trajectories together.""" + import numpy as np + + if not trajs: + return [] + + # Average all trajectories + avg_traj = {} + for key in trajs[0].keys(): + values = [traj[key] for traj in trajs if key in traj] + if values and isinstance(values[0], np.ndarray): + avg_traj[key] = np.mean(values, axis=0) + else: + avg_traj[key] = values[0] if values else None + + return [avg_traj] + +# ============================================================================ +# Example: Using the pipeline with class-based augmentations +# ============================================================================ + +def example(): + """Demonstrates the class-based augmentation style.""" + + spark = SparkSession.builder.appName("AugmentationExample").getOrCreate() + pipeline = Pipeline(spark) + + # Add user-defined augmentation classes + pipeline.add(AddOffsetAugmentation(offset=2.0)) + # Run global augmentation 3 times in parallel using Spark + # keep_original=True (default): output has 5 (original) + 3 (augmented) = 8 trajectories + pipeline.add(AverageTrajectoriesAugmentation(times=3)) + + # Example with keep_original=False: output has only 5 augmented trajectories + # pipeline.add(AverageTrajectoriesAugmentation(times=3, keep_original=False)) + + # Sample data: list of trajectory dictionaries + sample_data = [ + {"x": [1, 2, 3], "y": [4, 5, 6]}, + {"x": [2, 3, 4], "y": [5, 6, 7]}, + {"x": [3, 4, 5], "y": [6, 7, 8]}, + {"x": [7, 8, 9], "y": [10, 11, 12]}, + {"x": [11, 12, 13], "y": [14, 15, 16]}, + ] + + pipeline.set_data(sample_data) + + print("===== Finite Dataset =====") + spark_dataset = SparkIterableDataset(pipeline) + + spark_dataloader = DataLoader(spark_dataset, batch_size=6, num_workers=1) + + for i, data in enumerate(spark_dataloader): + print(f'[{i}] {data=}') + + print("===== Infinite Dataset =====") + num_batches = 4 + batch_count = 0 + + print(f"Iteration count {num_batches}") + + spark_dataset = SparkIterableDataset(pipeline, infinite=True) + + spark_dataloader = DataLoader(spark_dataset, batch_size=6, num_workers=1) + + for i, data in enumerate(spark_dataloader): + if batch_count == num_batches: + print("Iteration count reached") + break + + print(f'[{i}] {data=}') + batch_count += 1 + + spark.stop() + + +if __name__ == "__main__": + print("=" * 60) + print("Example: Class-based augmentation style") + print("=" * 60) + example()