From 5c6fa7480676c77f8dfd9b986debb38fac5659a1 Mon Sep 17 00:00:00 2001
From: xiaohongbo
Date: Wed, 28 Jan 2026 14:00:23 +0800
Subject: [PATCH 01/11] support multithread when converting to torch dataset

---
 .../pypaimon/read/datasource/torch_dataset.py | 124 ++++++++++++------
 paimon-python/pypaimon/read/table_read.py     |   6 +-
 2 files changed, 87 insertions(+), 43 deletions(-)

diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py
index a800295f9e8d..4fbcd3918b37 100644
--- a/paimon-python/pypaimon/read/datasource/torch_dataset.py
+++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -18,6 +18,9 @@
 """
 Module to read a Paimon table into PyTorch Dataset.
 """
+import math
+import queue
+import threading
 from typing import List
 
 import torch
@@ -83,60 +86,101 @@ class TorchIterDataset(IterableDataset):
     rather than loading everything into memory upfront.
     """
 
-    def __init__(self, table_read: TableRead, splits: List[Split]):
+    _SENTINEL = 0
+    _ROW = 1
+    _ERR = 2
+    _PREFETCH_QUEUE_MAXSIZE = 512
+
+    def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurrency: int = 1):
         """
         Initialize TorchIterDataset.
 
         Args:
            table_read: TableRead instance for reading data
            splits: List of splits to read
+            prefetch_concurrency: Number of threads to use for parallel OSS reads within
+                this worker (default 1). When > 1, splits are partitioned across
+                threads to increase read throughput.
         """
         self.table_read = table_read
         self.splits = splits
+        self.prefetch_concurrency = max(1, int(prefetch_concurrency))
         # Get field names from read_type
         self.field_names = [field.name for field in table_read.read_type]
 
+    def _row_to_dict(self, offset_row) -> dict:
+        row_dict = {}
+        for i, field_name in enumerate(self.field_names):
+            value = offset_row.get_field(i)
+            row_dict[field_name] = value
+        return row_dict
+
     def __iter__(self):
         """
-        Iterate over the dataset, converting each OffsetRow to a dictionary.
-
-        Supports multi-worker data loading by partitioning splits across workers.
-        When num_workers > 0 in DataLoader, each worker will process a subset of splits.
-
-        Yields:
-            row data of dict type, where keys are column names
+        Iterate over the dataset, converting each row to a dict keyed by field names.
""" worker_info = torch.utils.data.get_worker_info() - - if worker_info is None: - # Single-process data loading, iterate over all splits - splits_to_process = self.splits + if worker_info is not None: + per_worker = int(math.ceil(len(self.splits) / float(worker_info.num_workers))) + start = worker_info.id * per_worker + end = min(start + per_worker, len(self.splits)) + splits_for_worker = self.splits[start:end] else: - # Multi-process data loading, partition splits across workers - worker_id = worker_info.id - num_workers = worker_info.num_workers - - # Calculate start and end indices for this worker - # Distribute splits evenly by slicing - total_splits = len(self.splits) - splits_per_worker = total_splits // num_workers - remainder = total_splits % num_workers - - # Workers with id < remainder get one extra split - if worker_id < remainder: - start_idx = worker_id * (splits_per_worker + 1) - end_idx = start_idx + splits_per_worker + 1 - else: - start_idx = worker_id * splits_per_worker + remainder - end_idx = start_idx + splits_per_worker - - splits_to_process = self.splits[start_idx:end_idx] - - worker_iterator = self.table_read.to_iterator(splits_to_process) - - for offset_row in worker_iterator: - row_dict = {} - for i, field_name in enumerate(self.field_names): - value = offset_row.get_field(i) - row_dict[field_name] = value - yield row_dict + splits_for_worker = self.splits + + if self.prefetch_concurrency <= 1: + iterator = self.table_read.to_iterator(splits_for_worker) + for offset_row in iterator: + yield self._row_to_dict(offset_row) + return + + n = min(self.prefetch_concurrency, len(splits_for_worker)) + split_groups = [splits_for_worker[i::n] for i in range(n)] + if n == 0: + return + + q = queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE) + stop_event = threading.Event() + + def producer(thread_id: int, split_group: List): + try: + for offset_row in self.table_read.to_iterator(split_group): + if stop_event.is_set(): + break + try: + q.put((self._ROW, self._row_to_dict(offset_row)), timeout=30.0) + except queue.Full: + if stop_event.is_set(): + break + q.put((self._ROW, self._row_to_dict(offset_row))) + q.put((self._SENTINEL, thread_id)) + except Exception as e: + q.put((self._ERR, e)) + + threads = [ + threading.Thread(target=producer, args=(i, split_groups[i]), daemon=True) + for i in range(n) + ] + for t in threads: + t.start() + + try: + sentinel_count = 0 + while sentinel_count < n: + try: + tag, payload = q.get(timeout=300.0) + except queue.Empty: + if stop_event.is_set(): + break + continue + if tag == self._SENTINEL: + sentinel_count += 1 + elif tag == self._ERR: + raise payload + elif tag == self._ROW: + yield payload + finally: + stop_event.set() + for t in threads: + t.join(timeout=5.0) + diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index f546c4be6b3a..7acdb90036b8 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -204,7 +204,7 @@ def to_ray( You needn't manually set this in most cases. **read_args: Additional kwargs passed to the datasource. For example, ``per_task_row_limit`` (Ray 2.52.0+). - + See `Ray Data API `_ for details. 
""" @@ -231,11 +231,11 @@ def to_ray( **read_args ) - def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset": + def to_torch(self, splits: List[Split], streaming: bool = False, prefetch_concurrency: int = 1) -> "torch.utils.data.Dataset": """Wrap Paimon table data to PyTorch Dataset.""" if streaming: from pypaimon.read.datasource.torch_dataset import TorchIterDataset - dataset = TorchIterDataset(self, splits) + dataset = TorchIterDataset(self, splits, prefetch_concurrency) return dataset else: from pypaimon.read.datasource.torch_dataset import TorchDataset From 5a7729e3631095dff5e2efa6c1b30bbb813f3cfc Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 14:13:48 +0800 Subject: [PATCH 02/11] add test case for prefetch_concurrency --- .../pypaimon/tests/torch_read_test.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/paimon-python/pypaimon/tests/torch_read_test.py b/paimon-python/pypaimon/tests/torch_read_test.py index b6862c6cb127..6e4d5cdbb498 100644 --- a/paimon-python/pypaimon/tests/torch_read_test.py +++ b/paimon-python/pypaimon/tests/torch_read_test.py @@ -100,6 +100,42 @@ def test_torch_read(self, is_streaming: bool = False): print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with correct data") + def test_torch_streaming_prefetch_concurrency(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id']) + self.catalog.create_table('default.test_torch_prefetch_concurrency', schema, False) + table = self.catalog.get_table('default.test_torch_prefetch_concurrency') + self._write_test_table(table) + + read_builder = table.new_read_builder().with_projection(['user_id', 'behavior']) + table_scan = read_builder.new_scan() + table_read = read_builder.new_read() + splits = table_scan.plan().splits() + self.assertGreater(len(splits), 0, "Need at least one split to test prefetch") + + dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=4) + dataloader = DataLoader( + dataset, + batch_size=2, + num_workers=0, + shuffle=False + ) + + all_user_ids = [] + all_behaviors = [] + for batch_data in dataloader: + all_user_ids.extend(batch_data['user_id'].tolist()) + all_behaviors.extend(batch_data['behavior']) + + sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0]) + sorted_user_ids = [x[0] for x in sorted_data] + sorted_behaviors = [x[1] for x in sorted_data] + + expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8] + expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'] + self.assertEqual(len(all_user_ids), 8, "Should read 8 rows with prefetch_concurrency") + self.assertEqual(sorted_user_ids, expected_user_ids) + self.assertEqual(sorted_behaviors, expected_behaviors) + def test_blob_torch_read(self): """Test end-to-end blob functionality using blob descriptors.""" import random From dae10c6b7c9a00075a8bda1eecc2358ba255c645 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 14:44:22 +0800 Subject: [PATCH 03/11] add comment back --- paimon-python/pypaimon/read/datasource/torch_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 4fbcd3918b37..fb6abb121b27 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -117,7 +117,13 @@ def _row_to_dict(self, offset_row) -> dict: def __iter__(self): """ - Iterate 
over the dataset, converting each row to a dict keyed by field names. + Iterate over the dataset, converting each OffsetRow to a dictionary. + + Supports multi-worker data loading by partitioning splits across workers. + When num_workers > 0 in DataLoader, each worker will process a subset of splits. + + Yields: + row data of dict type, where keys are column names. """ worker_info = torch.utils.data.get_worker_info() if worker_info is not None: From 76d1ae168bb7a0712c9337f6c1bc7a9ccc98cff1 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 15:34:07 +0800 Subject: [PATCH 04/11] revert --- .../pypaimon/read/datasource/torch_dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index fb6abb121b27..0e07cc5e8919 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -123,16 +123,18 @@ def __iter__(self): When num_workers > 0 in DataLoader, each worker will process a subset of splits. Yields: - row data of dict type, where keys are column names. + row data of dict type, where keys are column names """ worker_info = torch.utils.data.get_worker_info() - if worker_info is not None: + if worker_info is None: + # Single-process data loading, iterate over all splits + splits_for_worker = self.splits + else: + # Multi-worker: partition splits across workers per_worker = int(math.ceil(len(self.splits) / float(worker_info.num_workers))) start = worker_info.id * per_worker end = min(start + per_worker, len(self.splits)) splits_for_worker = self.splits[start:end] - else: - splits_for_worker = self.splits if self.prefetch_concurrency <= 1: iterator = self.table_read.to_iterator(splits_for_worker) From 1119e2b62c8773e9f25210bcaa3f1b3f2c240711 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 20:51:32 +0800 Subject: [PATCH 05/11] clean code --- .../pypaimon/read/datasource/torch_dataset.py | 48 +++++++++---------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 0e07cc5e8919..0b78e9534881 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -127,68 +127,66 @@ def __iter__(self): """ worker_info = torch.utils.data.get_worker_info() if worker_info is None: - # Single-process data loading, iterate over all splits - splits_for_worker = self.splits + splits_to_process = self.splits else: - # Multi-worker: partition splits across workers per_worker = int(math.ceil(len(self.splits) / float(worker_info.num_workers))) start = worker_info.id * per_worker end = min(start + per_worker, len(self.splits)) - splits_for_worker = self.splits[start:end] + splits_to_process = self.splits[start:end] + for row in self._iter_rows(splits_to_process): + yield row + + def _iter_rows(self, splits: List[Split]): if self.prefetch_concurrency <= 1: - iterator = self.table_read.to_iterator(splits_for_worker) - for offset_row in iterator: + for offset_row in self.table_read.to_iterator(splits): yield self._row_to_dict(offset_row) return - n = min(self.prefetch_concurrency, len(splits_for_worker)) - split_groups = [splits_for_worker[i::n] for i in range(n)] + n = min(self.prefetch_concurrency, len(splits)) if n == 0: return + split_groups = [splits[i::n] for i in range(n)] q = 
queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE) - stop_event = threading.Event() + stop = threading.Event() - def producer(thread_id: int, split_group: List): + def producer(split_group: List): try: for offset_row in self.table_read.to_iterator(split_group): - if stop_event.is_set(): + if stop.is_set(): break try: q.put((self._ROW, self._row_to_dict(offset_row)), timeout=30.0) except queue.Full: - if stop_event.is_set(): - break - q.put((self._ROW, self._row_to_dict(offset_row))) - q.put((self._SENTINEL, thread_id)) + if not stop.is_set(): + q.put((self._ROW, self._row_to_dict(offset_row))) + q.put((self._SENTINEL, None)) except Exception as e: q.put((self._ERR, e)) - threads = [ - threading.Thread(target=producer, args=(i, split_groups[i]), daemon=True) - for i in range(n) - ] + threads = [threading.Thread(target=producer, args=(split_groups[i],), daemon=True) + for i in range(n)] for t in threads: t.start() try: - sentinel_count = 0 - while sentinel_count < n: + done = 0 + while done < n: try: tag, payload = q.get(timeout=300.0) except queue.Empty: - if stop_event.is_set(): + if stop.is_set(): break continue if tag == self._SENTINEL: - sentinel_count += 1 + done += 1 elif tag == self._ERR: raise payload - elif tag == self._ROW: + else: yield payload finally: - stop_event.set() + stop.set() for t in threads: t.join(timeout=5.0) From afbb670cf9fb69486304e40f5874bff8f4595650 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 20:56:48 +0800 Subject: [PATCH 06/11] clean code --- .../pypaimon/read/datasource/torch_dataset.py | 43 +++++++++++++------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 0b78e9534881..27e66e5267fc 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -18,7 +18,6 @@ """ Module to read a Paimon table into PyTorch Dataset. 
""" -import math import queue import threading from typing import List @@ -126,23 +125,43 @@ def __iter__(self): row data of dict type, where keys are column names """ worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + # Single-process data loading, iterate over all splits splits_to_process = self.splits else: - per_worker = int(math.ceil(len(self.splits) / float(worker_info.num_workers))) - start = worker_info.id * per_worker - end = min(start + per_worker, len(self.splits)) - splits_to_process = self.splits[start:end] + # Multi-process data loading, partition splits across workers + worker_id = worker_info.id + num_workers = worker_info.num_workers + + total_splits = len(self.splits) + splits_per_worker = total_splits // num_workers + remainder = total_splits % num_workers + + if worker_id < remainder: + start_idx = worker_id * (splits_per_worker + 1) + end_idx = start_idx + splits_per_worker + 1 + else: + start_idx = worker_id * splits_per_worker + remainder + end_idx = start_idx + splits_per_worker + + splits_to_process = self.splits[start_idx:end_idx] + + if self.prefetch_concurrency > 1: + for row in self._iter_rows(splits_to_process): + yield row + return - for row in self._iter_rows(splits_to_process): - yield row + worker_iterator = self.table_read.to_iterator(splits_to_process) - def _iter_rows(self, splits: List[Split]): - if self.prefetch_concurrency <= 1: - for offset_row in self.table_read.to_iterator(splits): - yield self._row_to_dict(offset_row) - return + for offset_row in worker_iterator: + row_dict = {} + for i, field_name in enumerate(self.field_names): + value = offset_row.get_field(i) + row_dict[field_name] = value + yield row_dict + def _iter_rows(self, splits: List[Split]): n = min(self.prefetch_concurrency, len(splits)) if n == 0: return From 96ea77e800cf0574b039184f3ef7a5f4886cdd64 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 21:00:04 +0800 Subject: [PATCH 07/11] clean code --- paimon-python/pypaimon/read/table_read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index 7acdb90036b8..c1744f69d7aa 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -204,7 +204,7 @@ def to_ray( You needn't manually set this in most cases. **read_args: Additional kwargs passed to the datasource. For example, ``per_task_row_limit`` (Ray 2.52.0+). - + See `Ray Data API `_ for details. 
""" From d01966ab2fc4b3a066f4af5317219aa17e7a5258 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 21:05:43 +0800 Subject: [PATCH 08/11] clean code --- .../pypaimon/read/datasource/torch_dataset.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 27e66e5267fc..0f061ca0119f 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -89,6 +89,9 @@ class TorchIterDataset(IterableDataset): _ROW = 1 _ERR = 2 _PREFETCH_QUEUE_MAXSIZE = 512 + _PREFETCH_PUT_TIMEOUT_SEC = 30.0 + _PREFETCH_GET_TIMEOUT_SEC = 300.0 + _PREFETCH_JOIN_TIMEOUT_SEC = 5.0 def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurrency: int = 1): """ @@ -134,10 +137,13 @@ def __iter__(self): worker_id = worker_info.id num_workers = worker_info.num_workers + # Calculate start and end indices for this worker + # Distribute splits evenly by slicing total_splits = len(self.splits) splits_per_worker = total_splits // num_workers remainder = total_splits % num_workers + # Workers with id < remainder get one extra split if worker_id < remainder: start_idx = worker_id * (splits_per_worker + 1) end_idx = start_idx + splits_per_worker + 1 @@ -176,7 +182,7 @@ def producer(split_group: List): if stop.is_set(): break try: - q.put((self._ROW, self._row_to_dict(offset_row)), timeout=30.0) + q.put((self._ROW, self._row_to_dict(offset_row)), timeout=self._PREFETCH_PUT_TIMEOUT_SEC) except queue.Full: if not stop.is_set(): q.put((self._ROW, self._row_to_dict(offset_row))) @@ -193,7 +199,7 @@ def producer(split_group: List): done = 0 while done < n: try: - tag, payload = q.get(timeout=300.0) + tag, payload = q.get(timeout=self._PREFETCH_GET_TIMEOUT_SEC) except queue.Empty: if stop.is_set(): break @@ -207,5 +213,5 @@ def producer(split_group: List): finally: stop.set() for t in threads: - t.join(timeout=5.0) + t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC) From 288cfec34df2cdd796f9b4f1b5758e66240cdbe5 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 21:22:49 +0800 Subject: [PATCH 09/11] fix hang producer threads indefinitely issue --- .../pypaimon/read/datasource/torch_dataset.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index 0f061ca0119f..e432908bc43e 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -176,19 +176,26 @@ def _iter_rows(self, splits: List[Split]): q = queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE) stop = threading.Event() + def put_item(tag: int, payload): + while not stop.is_set(): + try: + q.put((tag, payload), timeout=self._PREFETCH_PUT_TIMEOUT_SEC) + return True + except queue.Full: + continue + return False + def producer(split_group: List): try: for offset_row in self.table_read.to_iterator(split_group): if stop.is_set(): break - try: - q.put((self._ROW, self._row_to_dict(offset_row)), timeout=self._PREFETCH_PUT_TIMEOUT_SEC) - except queue.Full: - if not stop.is_set(): - q.put((self._ROW, self._row_to_dict(offset_row))) - q.put((self._SENTINEL, None)) + row_dict = self._row_to_dict(offset_row) + if not put_item(self._ROW, row_dict): + break + put_item(self._SENTINEL, None) except Exception as e: - 
q.put((self._ERR, e)) + put_item(self._ERR, e) threads = [threading.Thread(target=producer, args=(split_groups[i],), daemon=True) for i in range(n)] From c01854ceedce2d5ba373aa0ef2bc2705a965e5cf Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Wed, 28 Jan 2026 22:05:50 +0800 Subject: [PATCH 10/11] fix code format --- paimon-python/pypaimon/read/datasource/torch_dataset.py | 1 - paimon-python/pypaimon/read/table_read.py | 7 ++++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paimon-python/pypaimon/read/datasource/torch_dataset.py b/paimon-python/pypaimon/read/datasource/torch_dataset.py index e432908bc43e..97ebc5356636 100644 --- a/paimon-python/pypaimon/read/datasource/torch_dataset.py +++ b/paimon-python/pypaimon/read/datasource/torch_dataset.py @@ -221,4 +221,3 @@ def producer(split_group: List): stop.set() for t in threads: t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC) - diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index c1744f69d7aa..dcda789191a4 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -231,7 +231,12 @@ def to_ray( **read_args ) - def to_torch(self, splits: List[Split], streaming: bool = False, prefetch_concurrency: int = 1) -> "torch.utils.data.Dataset": + def to_torch( + self, + splits: List[Split], + streaming: bool = False, + prefetch_concurrency: int = 1, + ) -> "torch.utils.data.Dataset": """Wrap Paimon table data to PyTorch Dataset.""" if streaming: from pypaimon.read.datasource.torch_dataset import TorchIterDataset From 2f7b3180e945b36cdb0ad77217c44606a4b8adad Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Sat, 31 Jan 2026 16:54:18 +0800 Subject: [PATCH 11/11] add doc for prefetch_concurrency --- docs/content/pypaimon/pytorch.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/content/pypaimon/pytorch.md b/docs/content/pypaimon/pytorch.md index b34f49edcd6d..6ab485f2c989 100644 --- a/docs/content/pypaimon/pytorch.md +++ b/docs/content/pypaimon/pytorch.md @@ -37,7 +37,7 @@ You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data from torch.utils.data import DataLoader table_read = read_builder.new_read() -dataset = table_read.to_torch(splits, streaming=True) +dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2) dataloader = DataLoader( dataset, batch_size=2, @@ -58,3 +58,5 @@ for batch_idx, batch_data in enumerate(dataloader): When the `streaming` parameter is true, it will iteratively read; when it is false, it will read the full amount of data into memory. + +**`prefetch_concurrency`** (default: 1): When streaming is true, number of threads used for parallel prefetch within each DataLoader worker. Set to a value greater than 1 to partition splits across threads and increase read throughput. Has no effect when streaming is false.
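
A minimal usage sketch of the option this series adds, assembled only from the APIs exercised in the test and doc patches above. The table, column names, and batch sizes are placeholders; `table` is assumed to be an existing Paimon table already obtained from a catalog:

```python
from torch.utils.data import DataLoader

# `table` is assumed to be a Paimon table obtained from a catalog, e.g. via
# catalog.get_table('default.my_table') as in the test case above.
read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
splits = read_builder.new_scan().plan().splits()
table_read = read_builder.new_read()

# streaming=True returns the iterable TorchIterDataset; prefetch_concurrency=4
# partitions this worker's splits across up to four producer threads feeding a
# bounded queue, so object-store reads overlap with consumption.
dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=4)

# num_workers > 0 composes with prefetching: splits are first divided across
# DataLoader workers, then across each worker's producer threads.
dataloader = DataLoader(dataset, batch_size=32, num_workers=2)

for batch in dataloader:
    user_ids = batch['user_id']  # batches are dicts keyed by column name
```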