diff --git a/docs/content/pypaimon/pytorch.md b/docs/content/pypaimon/pytorch.md
index b34f49edcd6d..6ab485f2c989 100644
--- a/docs/content/pypaimon/pytorch.md
+++ b/docs/content/pypaimon/pytorch.md
@@ -37,7 +37,7 @@ You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data
 from torch.utils.data import DataLoader
 
 table_read = read_builder.new_read()
-dataset = table_read.to_torch(splits, streaming=True)
+dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2)
 dataloader = DataLoader(
     dataset,
     batch_size=2,
@@ -58,3 +58,5 @@
 
 When the `streaming` parameter is true, it will iteratively read; when it is false, it will read
 the full amount of data into memory.
+
+**`prefetch_concurrency`** (default: 1): When `streaming` is true, the number of threads used for parallel prefetching within each DataLoader worker. Set it to a value greater than 1 to partition splits across threads and increase read throughput. It has no effect when `streaming` is false.
diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py
index a800295f9e8d..97ebc5356636 100644
--- a/paimon-python/pypaimon/read/datasource/torch_dataset.py
+++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -18,6 +18,8 @@
 """
 Module to read a Paimon table into PyTorch Dataset.
 """
+import queue
+import threading
 from typing import List
 
 import torch
@@ -83,19 +85,38 @@ class TorchIterDataset(IterableDataset):
     rather than loading everything into memory upfront.
     """
 
-    def __init__(self, table_read: TableRead, splits: List[Split]):
+    _SENTINEL = 0
+    _ROW = 1
+    _ERR = 2
+    _PREFETCH_QUEUE_MAXSIZE = 512
+    _PREFETCH_PUT_TIMEOUT_SEC = 30.0
+    _PREFETCH_GET_TIMEOUT_SEC = 300.0
+    _PREFETCH_JOIN_TIMEOUT_SEC = 5.0
+
+    def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurrency: int = 1):
         """
         Initialize TorchIterDataset.
 
         Args:
             table_read: TableRead instance for reading data
             splits: List of splits to read
+            prefetch_concurrency: Number of threads to use for parallel OSS reads within
+                this worker (default 1). When > 1, splits are partitioned across
+                threads to increase read throughput.
         """
         self.table_read = table_read
         self.splits = splits
+        self.prefetch_concurrency = max(1, int(prefetch_concurrency))
         # Get field names from read_type
         self.field_names = [field.name for field in table_read.read_type]
 
+    def _row_to_dict(self, offset_row) -> dict:
+        row_dict = {}
+        for i, field_name in enumerate(self.field_names):
+            value = offset_row.get_field(i)
+            row_dict[field_name] = value
+        return row_dict
+
     def __iter__(self):
         """
         Iterate over the dataset, converting each OffsetRow to a dictionary.
@@ -132,6 +153,11 @@ def __iter__(self):
 
         splits_to_process = self.splits[start_idx:end_idx]
 
+        if self.prefetch_concurrency > 1:
+            for row in self._iter_rows(splits_to_process):
+                yield row
+            return
+
         worker_iterator = self.table_read.to_iterator(splits_to_process)
 
         for offset_row in worker_iterator:
@@ -140,3 +166,58 @@ def __iter__(self):
                 value = offset_row.get_field(i)
                 row_dict[field_name] = value
             yield row_dict
+
+    def _iter_rows(self, splits: List[Split]):
+        n = min(self.prefetch_concurrency, len(splits))
+        if n == 0:
+            return
+        split_groups = [splits[i::n] for i in range(n)]
+
+        q = queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE)
+        stop = threading.Event()
+
+        def put_item(tag: int, payload):
+            while not stop.is_set():
+                try:
+                    q.put((tag, payload), timeout=self._PREFETCH_PUT_TIMEOUT_SEC)
+                    return True
+                except queue.Full:
+                    continue
+            return False
+
+        def producer(split_group: List):
+            try:
+                for offset_row in self.table_read.to_iterator(split_group):
+                    if stop.is_set():
+                        break
+                    row_dict = self._row_to_dict(offset_row)
+                    if not put_item(self._ROW, row_dict):
+                        break
+                put_item(self._SENTINEL, None)
+            except Exception as e:
+                put_item(self._ERR, e)
+
+        threads = [threading.Thread(target=producer, args=(split_groups[i],), daemon=True)
+                   for i in range(n)]
+        for t in threads:
+            t.start()
+
+        try:
+            done = 0
+            while done < n:
+                try:
+                    tag, payload = q.get(timeout=self._PREFETCH_GET_TIMEOUT_SEC)
+                except queue.Empty:
+                    if stop.is_set():
+                        break
+                    continue
+                if tag == self._SENTINEL:
+                    done += 1
+                elif tag == self._ERR:
+                    raise payload
+                else:
+                    yield payload
+        finally:
+            stop.set()
+            for t in threads:
+                t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index f546c4be6b3a..dcda789191a4 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -231,11 +231,16 @@ def to_ray(
             **read_args
         )
 
-    def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset":
+    def to_torch(
+        self,
+        splits: List[Split],
+        streaming: bool = False,
+        prefetch_concurrency: int = 1,
+    ) -> "torch.utils.data.Dataset":
         """Wrap Paimon table data to PyTorch Dataset."""
         if streaming:
             from pypaimon.read.datasource.torch_dataset import TorchIterDataset
-            dataset = TorchIterDataset(self, splits)
+            dataset = TorchIterDataset(self, splits, prefetch_concurrency)
             return dataset
         else:
             from pypaimon.read.datasource.torch_dataset import TorchDataset
diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py
index b6862c6cb127..6e4d5cdbb498 100644
--- a/paimon-python/pypaimon/tests/torch_read_test.py
+++ b/paimon-python/pypaimon/tests/torch_read_test.py
@@ -100,6 +100,42 @@ def test_torch_read(self, is_streaming: bool = False):
 
         print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with correct data")
 
+    def test_torch_streaming_prefetch_concurrency(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
+        self.catalog.create_table('default.test_torch_prefetch_concurrency', schema, False)
+        table = self.catalog.get_table('default.test_torch_prefetch_concurrency')
+        self._write_test_table(table)
+
+        read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        self.assertGreater(len(splits), 0, "Need at least one split to test prefetch")
prefetch") + + dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=4) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + all_behaviors = [] + for batch_data in dataloader: + all_user_ids.extend(batch_data['user_id'].tolist()) + all_behaviors.extend(batch_data['behavior']) + + sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0]) + sorted_user_ids = [x[0] for x in sorted_data] + sorted_behaviors = [x[1] for x in sorted_data] + + expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8] + expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] + self.assertEqual(len(all_user_ids), 8, "Should read 8 rows with prefetch_concurrency") + self.assertEqual(sorted_user_ids, expected_user_ids) + self.assertEqual(sorted_behaviors, expected_behaviors) + def test_blob_torch_read(self): """Test end-to-end blob functionality using blob descriptors.""" import random