4 changes: 3 additions & 1 deletion docs/content/pypaimon/pytorch.md
@@ -37,7 +37,7 @@ You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data
from torch.utils.data import DataLoader

table_read = read_builder.new_read()
dataset = table_read.to_torch(splits, streaming=True)
dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2)
dataloader = DataLoader(
dataset,
batch_size=2,
@@ -58,3 +58,5 @@ for batch_idx, batch_data in enumerate(dataloader):

When the `streaming` parameter is true, data is read iteratively;
when it is false, the full dataset is loaded into memory.

**`prefetch_concurrency`** (default: 1): When `streaming` is true, the number of threads used for parallel prefetching within each DataLoader worker. Set it to a value greater than 1 to partition splits across threads and increase read throughput. It has no effect when `streaming` is false.
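
For illustration, a minimal sketch of combining `prefetch_concurrency` with multiple DataLoader workers, assuming the `read_builder` and `splits` set up earlier on this page: each worker reads its own share of the splits, and within a worker that share is further partitioned across prefetch threads.

```python
from torch.utils.data import DataLoader

table_read = read_builder.new_read()
dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=2)

# Each DataLoader worker gets a share of the splits and partitions that share
# across its own prefetch threads, so this configuration uses roughly
# num_workers * prefetch_concurrency = 4 concurrent split readers.
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)

for batch in dataloader:
    ...  # training / preprocessing
```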
83 changes: 82 additions & 1 deletion paimon-python/pypaimon/read/datasource/torch_dataset.py
@@ -18,6 +18,8 @@
"""
Module to read a Paimon table into PyTorch Dataset.
"""
import queue
import threading
from typing import List

import torch
@@ -83,19 +85,38 @@ class TorchIterDataset(IterableDataset):
rather than loading everything into memory upfront.
"""

def __init__(self, table_read: TableRead, splits: List[Split]):
_SENTINEL = 0
_ROW = 1
_ERR = 2
_PREFETCH_QUEUE_MAXSIZE = 512
_PREFETCH_PUT_TIMEOUT_SEC = 30.0
_PREFETCH_GET_TIMEOUT_SEC = 300.0
_PREFETCH_JOIN_TIMEOUT_SEC = 5.0

def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurrency: int = 1):
"""
Initialize TorchIterDataset.

Args:
table_read: TableRead instance for reading data
splits: List of splits to read
prefetch_concurrency: Number of threads to use for parallel OSS reads within
this worker (default 1). When > 1, splits are partitioned across
threads to increase read throughput.
"""
self.table_read = table_read
self.splits = splits
self.prefetch_concurrency = max(1, int(prefetch_concurrency))
# Get field names from read_type
self.field_names = [field.name for field in table_read.read_type]

def _row_to_dict(self, offset_row) -> dict:
row_dict = {}
for i, field_name in enumerate(self.field_names):
value = offset_row.get_field(i)
row_dict[field_name] = value
return row_dict

def __iter__(self):
"""
Iterate over the dataset, converting each OffsetRow to a dictionary.
@@ -132,6 +153,11 @@ def __iter__(self):

splits_to_process = self.splits[start_idx:end_idx]

if self.prefetch_concurrency > 1:
for row in self._iter_rows(splits_to_process):
yield row
return

worker_iterator = self.table_read.to_iterator(splits_to_process)

for offset_row in worker_iterator:
@@ -140,3 +166,58 @@
value = offset_row.get_field(i)
row_dict[field_name] = value
yield row_dict

def _iter_rows(self, splits: List[Split]):
n = min(self.prefetch_concurrency, len(splits))
if n == 0:
return
split_groups = [splits[i::n] for i in range(n)]

q = queue.Queue(maxsize=self._PREFETCH_QUEUE_MAXSIZE)
stop = threading.Event()

def put_item(tag: int, payload):
while not stop.is_set():
try:
q.put((tag, payload), timeout=self._PREFETCH_PUT_TIMEOUT_SEC)
return True
except queue.Full:
continue
return False

def producer(split_group: List):
try:
for offset_row in self.table_read.to_iterator(split_group):
if stop.is_set():
break
row_dict = self._row_to_dict(offset_row)
if not put_item(self._ROW, row_dict):
break
put_item(self._SENTINEL, None)
except Exception as e:
put_item(self._ERR, e)

threads = [threading.Thread(target=producer, args=(split_groups[i],), daemon=True)
for i in range(n)]
for t in threads:
t.start()

try:
done = 0
while done < n:
try:
tag, payload = q.get(timeout=self._PREFETCH_GET_TIMEOUT_SEC)
except queue.Empty:
if stop.is_set():
break
continue
if tag == self._SENTINEL:
done += 1
elif tag == self._ERR:
raise payload
else:
yield payload
finally:
stop.set()
for t in threads:
t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)
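
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the producer-consumer scheme that `_iter_rows` implements above: each producer thread drains one round-robin group of splits into a bounded queue and emits a sentinel when finished, and the consumer yields rows until it has seen one sentinel per producer, re-raising any exception a producer enqueued. The `read_split` helper is a hypothetical stand-in for `table_read.to_iterator`, and the stop-event and timeout handling of the real code is omitted for brevity.

```python
import queue
import threading
from typing import Iterator, List

_SENTINEL, _ROW, _ERR = 0, 1, 2


def read_split(split: int) -> Iterator[dict]:
    # Stand-in for a real split reader; yields a few rows per split.
    for i in range(3):
        yield {"split": split, "row": i}


def iter_rows(splits: List[int], concurrency: int) -> Iterator[dict]:
    n = min(concurrency, len(splits))
    if n == 0:
        return
    groups = [splits[i::n] for i in range(n)]  # round-robin partitioning
    q: queue.Queue = queue.Queue(maxsize=64)   # bounded to limit memory use

    def producer(group: List[int]) -> None:
        try:
            for split in group:
                for row in read_split(split):
                    q.put((_ROW, row))
            q.put((_SENTINEL, None))           # signal this producer is done
        except Exception as e:
            q.put((_ERR, e))                   # propagate errors to the consumer

    threads = [threading.Thread(target=producer, args=(g,), daemon=True) for g in groups]
    for t in threads:
        t.start()

    done = 0
    while done < n:                            # stop after one sentinel per producer
        tag, payload = q.get()
        if tag == _SENTINEL:
            done += 1
        elif tag == _ERR:
            raise payload
        else:
            yield payload


if __name__ == "__main__":
    for row in iter_rows(list(range(4)), concurrency=2):
        print(row)
```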
9 changes: 7 additions & 2 deletions paimon-python/pypaimon/read/table_read.py
@@ -231,11 +231,16 @@ def to_ray(
**read_args
)

def to_torch(self, splits: List[Split], streaming: bool = False) -> "torch.utils.data.Dataset":
def to_torch(
self,
splits: List[Split],
streaming: bool = False,
prefetch_concurrency: int = 1,
) -> "torch.utils.data.Dataset":
"""Wrap Paimon table data to PyTorch Dataset."""
if streaming:
from pypaimon.read.datasource.torch_dataset import TorchIterDataset
dataset = TorchIterDataset(self, splits)
dataset = TorchIterDataset(self, splits, prefetch_concurrency)
return dataset
else:
from pypaimon.read.datasource.torch_dataset import TorchDataset
36 changes: 36 additions & 0 deletions paimon-python/pypaimon/tests/torch_read_test.py
@@ -100,6 +100,42 @@ def test_torch_read(self, is_streaming: bool = False):

print(f"✓ Test passed: Successfully read {len(all_user_ids)} rows with correct data")

def test_torch_streaming_prefetch_concurrency(self):
schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
self.catalog.create_table('default.test_torch_prefetch_concurrency', schema, False)
table = self.catalog.get_table('default.test_torch_prefetch_concurrency')
self._write_test_table(table)

read_builder = table.new_read_builder().with_projection(['user_id', 'behavior'])
table_scan = read_builder.new_scan()
table_read = read_builder.new_read()
splits = table_scan.plan().splits()
self.assertGreater(len(splits), 0, "Need at least one split to test prefetch")

dataset = table_read.to_torch(splits, streaming=True, prefetch_concurrency=4)
dataloader = DataLoader(
dataset,
batch_size=2,
num_workers=0,
shuffle=False
)

all_user_ids = []
all_behaviors = []
for batch_data in dataloader:
all_user_ids.extend(batch_data['user_id'].tolist())
all_behaviors.extend(batch_data['behavior'])

sorted_data = sorted(zip(all_user_ids, all_behaviors), key=lambda x: x[0])
sorted_user_ids = [x[0] for x in sorted_data]
sorted_behaviors = [x[1] for x in sorted_data]

expected_user_ids = [1, 2, 3, 4, 5, 6, 7, 8]
expected_behaviors = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
self.assertEqual(len(all_user_ids), 8, "Should read 8 rows with prefetch_concurrency")
self.assertEqual(sorted_user_ids, expected_user_ids)
self.assertEqual(sorted_behaviors, expected_behaviors)

def test_blob_torch_read(self):
"""Test end-to-end blob functionality using blob descriptors."""
import random