diff --git a/paimon-python/pypaimon/catalog/rest/rest_catalog.py b/paimon-python/pypaimon/catalog/rest/rest_catalog.py index 41a3061fb93c..425146fc6829 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_catalog.py +++ b/paimon-python/pypaimon/catalog/rest/rest_catalog.py @@ -42,6 +42,10 @@ from pypaimon.snapshot.snapshot import Snapshot from pypaimon.snapshot.snapshot_commit import PartitionStatistics from pypaimon.table.file_store_table import FileStoreTable +from pypaimon.table.format.format_table import FormatTable, Format + + +FORMAT_TABLE_TYPE = "format-table" class RESTCatalog(Catalog): @@ -180,7 +184,7 @@ def list_tables_paged( except ForbiddenException as e: raise DatabaseNoPermissionException(database_name) from e - def get_table(self, identifier: Union[str, Identifier]) -> FileStoreTable: + def get_table(self, identifier: Union[str, Identifier]): if not isinstance(identifier, Identifier): identifier = Identifier.from_string(identifier) return self.load_table( @@ -263,9 +267,12 @@ def load_table(self, internal_file_io: Callable[[str], Any], external_file_io: Callable[[str], Any], metadata_loader: Callable[[Identifier], TableMetadata], - ) -> FileStoreTable: + ): metadata = metadata_loader(identifier) schema = metadata.schema + table_type = schema.options.get(CoreOptions.TYPE.key(), "").strip().lower() + if table_type == FORMAT_TABLE_TYPE: + return self._create_format_table(identifier, metadata, internal_file_io, external_file_io) data_file_io = external_file_io if metadata.is_external else internal_file_io catalog_env = CatalogEnvironment( identifier=identifier, @@ -281,6 +288,30 @@ def load_table(self, catalog_env) return table + def _create_format_table(self, + identifier: Identifier, + metadata: TableMetadata, + internal_file_io: Callable[[str], Any], + external_file_io: Callable[[str], Any], + ) -> FormatTable: + schema = metadata.schema + location = schema.options.get(CoreOptions.PATH.key()) + if not location: + raise ValueError("Format table schema must have path option") + data_file_io = external_file_io if metadata.is_external else internal_file_io + file_io = data_file_io(location) + file_format = schema.options.get(CoreOptions.FILE_FORMAT.key(), "parquet") + fmt = Format.parse(file_format) + return FormatTable( + file_io=file_io, + identifier=identifier, + table_schema=schema, + location=location, + format=fmt, + options=dict(schema.options), + comment=schema.comment, + ) + @staticmethod def create(file_io: FileIO, table_path: str, diff --git a/paimon-python/pypaimon/read/push_down_utils.py b/paimon-python/pypaimon/read/push_down_utils.py index f8123411490c..7ad7e53acccd 100644 --- a/paimon-python/pypaimon/read/push_down_utils.py +++ b/paimon-python/pypaimon/read/push_down_utils.py @@ -16,12 +16,29 @@ # limitations under the License. 
################################################################################ -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set from pypaimon.common.predicate import Predicate from pypaimon.common.predicate_builder import PredicateBuilder +def extract_partition_spec_from_predicate( + predicate: Predicate, partition_keys: List[str] +) -> Optional[Dict[str, str]]: + if not predicate or not partition_keys: + return None + parts = _split_and(predicate) + spec: Dict[str, str] = {} + for p in parts: + if p.method != "equal" or p.field is None or p.literals is None or len(p.literals) != 1: + continue + if p.field in partition_keys: + spec[p.field] = str(p.literals[0]) + if set(spec.keys()) == set(partition_keys): + return spec + return None + + def trim_and_transform_predicate(input_predicate: Predicate, all_fields: List[str], trimmed_keys: List[str]): new_predicate = trim_predicate_by_fields(input_predicate, trimmed_keys) part_to_index = {element: idx for idx, element in enumerate(trimmed_keys)} diff --git a/paimon-python/pypaimon/table/format/__init__.py b/paimon-python/pypaimon/table/format/__init__.py new file mode 100644 index 000000000000..228f165248ed --- /dev/null +++ b/paimon-python/pypaimon/table/format/__init__.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pypaimon.table.format.format_data_split import FormatDataSplit +from pypaimon.table.format.format_table import FormatTable, Format +from pypaimon.table.format.format_read_builder import FormatReadBuilder +from pypaimon.table.format.format_table_scan import FormatTableScan +from pypaimon.table.format.format_table_read import FormatTableRead +from pypaimon.table.format.format_batch_write_builder import FormatBatchWriteBuilder +from pypaimon.table.format.format_table_write import FormatTableWrite +from pypaimon.table.format.format_table_commit import FormatTableCommit + +__all__ = [ + "FormatDataSplit", + "FormatTable", + "Format", + "FormatReadBuilder", + "FormatTableScan", + "FormatTableRead", + "FormatBatchWriteBuilder", + "FormatTableWrite", + "FormatTableCommit", +] diff --git a/paimon-python/pypaimon/table/format/format_batch_write_builder.py b/paimon-python/pypaimon/table/format/format_batch_write_builder.py new file mode 100644 index 000000000000..31d865020a13 --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_batch_write_builder.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from pypaimon.table.format.format_table import FormatTable +from pypaimon.table.format.format_table_commit import FormatTableCommit +from pypaimon.table.format.format_table_write import FormatTableWrite + + +class FormatBatchWriteBuilder: + def __init__(self, table: FormatTable): + self.table = table + self._overwrite = False + self._static_partition: Optional[dict] = None + + def overwrite(self, static_partition: Optional[dict] = None) -> "FormatBatchWriteBuilder": + self._overwrite = True + self._validate_static_partition(static_partition) + self._static_partition = static_partition if static_partition is not None else {} + return self + + def _validate_static_partition(self, static_partition: Optional[dict]) -> None: + if not static_partition: + return + if not self.table.partition_keys: + raise ValueError( + "Format table is not partitioned, static partition values are not allowed." + ) + for key in static_partition: + if key not in self.table.partition_keys: + raise ValueError(f"Unknown static partition column: {key}") + + def new_write(self) -> FormatTableWrite: + return FormatTableWrite( + self.table, + overwrite=self._overwrite, + static_partitions=self._static_partition, + ) + + def new_commit(self) -> FormatTableCommit: + return FormatTableCommit(table=self.table) diff --git a/paimon-python/pypaimon/table/format/format_commit_message.py b/paimon-python/pypaimon/table/format/format_commit_message.py new file mode 100644 index 000000000000..c9a253ef58f9 --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_commit_message.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List + + +@dataclass +class FormatTableCommitMessage: + written_paths: List[str] + + def is_empty(self) -> bool: + return not self.written_paths diff --git a/paimon-python/pypaimon/table/format/format_data_split.py b/paimon-python/pypaimon/table/format/format_data_split.py new file mode 100644 index 000000000000..8536a18025d6 --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_data_split.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Dict, Optional, Any + + +@dataclass(frozen=True) +class FormatDataSplit: + """Split for format table: one file (or future: byte range) per split.""" + + file_path: str + file_size: int + partition: Optional[Dict[str, Any]] = None # partition column name -> value + + def data_path(self) -> str: + return self.file_path diff --git a/paimon-python/pypaimon/table/format/format_read_builder.py b/paimon-python/pypaimon/table/format/format_read_builder.py new file mode 100644 index 000000000000..ea501d56f0b1 --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_read_builder.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional + +from pypaimon.common.predicate import Predicate +from pypaimon.common.predicate_builder import PredicateBuilder +from pypaimon.read.push_down_utils import extract_partition_spec_from_predicate +from pypaimon.schema.data_types import DataField +from pypaimon.table.format.format_table import FormatTable +from pypaimon.table.format.format_table_scan import FormatTableScan +from pypaimon.table.format.format_table_read import FormatTableRead + + +class FormatReadBuilder: + def __init__(self, table: FormatTable): + self.table = table + self._projection: Optional[List[str]] = None + self._limit: Optional[int] = None + self._partition_filter: Optional[dict] = None + + def with_filter(self, predicate: Predicate) -> "FormatReadBuilder": + ok = ( + self._partition_filter is None + and self.table.partition_keys + and predicate + ) + if ok: + spec = extract_partition_spec_from_predicate( + predicate, self.table.partition_keys + ) + if spec is not None: + self._partition_filter = spec + return self + + def with_projection(self, projection: List[str]) -> "FormatReadBuilder": + self._projection = projection + return self + + def with_limit(self, limit: int) -> "FormatReadBuilder": + self._limit = limit + return self + + def with_partition_filter( + self, partition_spec: Optional[dict] + ) -> "FormatReadBuilder": + self._partition_filter = partition_spec + return self + + def new_scan(self) -> FormatTableScan: + return FormatTableScan( + self.table, + partition_filter=self._partition_filter, + ) + + def new_read(self) -> FormatTableRead: + return FormatTableRead( + table=self.table, + projection=self._projection, + limit=self._limit, + ) + + def new_predicate_builder(self) -> PredicateBuilder: + return PredicateBuilder(self.read_type()) + + def read_type(self) -> List[DataField]: + if self._projection: + return [f for f in self.table.fields if f.name in self._projection] + return list(self.table.fields) diff --git a/paimon-python/pypaimon/table/format/format_table.py b/paimon-python/pypaimon/table/format/format_table.py new file mode 100644 index 000000000000..564bd086255d --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_table.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum +from typing import Dict, List, Optional + +from pypaimon.common.file_io import FileIO +from pypaimon.common.identifier import Identifier +from pypaimon.schema.table_schema import TableSchema +from pypaimon.table.table import Table + + +class Format(str, Enum): + ORC = "orc" + PARQUET = "parquet" + CSV = "csv" + TEXT = "text" + JSON = "json" + + @classmethod + def parse(cls, file_format: str) -> "Format": + s = (file_format or "parquet").strip().upper() + try: + return cls[s] + except KeyError: + raise ValueError( + f"Format table unsupported file format: {file_format}. " + f"Supported: {[f.name for f in cls]}" + ) + + +class FormatTable(Table): + def __init__( + self, + file_io: FileIO, + identifier: Identifier, + table_schema: TableSchema, + location: str, + format: Format, + options: Optional[Dict[str, str]] = None, + comment: Optional[str] = None, + ): + self.file_io = file_io + self.identifier = identifier + self._table_schema = table_schema + self._location = location.rstrip("/") + self._format = format + self._options = options or dict(table_schema.options) + self.comment = comment + self.fields = table_schema.fields + self.field_names = [f.name for f in self.fields] + self.partition_keys = table_schema.partition_keys or [] + self.primary_keys: List[str] = [] # format table has no primary key + + def name(self) -> str: + return self.identifier.get_table_name() + + def full_name(self) -> str: + return self.identifier.get_full_name() + + @property + def table_schema(self) -> TableSchema: + return self._table_schema + + @table_schema.setter + def table_schema(self, value: TableSchema): + self._table_schema = value + + def location(self) -> str: + return self._location + + def format(self) -> Format: + return self._format + + def options(self) -> Dict[str, str]: + return self._options + + def new_read_builder(self): + from pypaimon.table.format.format_read_builder import FormatReadBuilder + return FormatReadBuilder(self) + + def new_batch_write_builder(self): + from pypaimon.table.format.format_batch_write_builder import FormatBatchWriteBuilder + return FormatBatchWriteBuilder(self) + + def new_stream_write_builder(self): + raise NotImplementedError("Format table does not support stream write.") diff --git a/paimon-python/pypaimon/table/format/format_table_commit.py b/paimon-python/pypaimon/table/format/format_table_commit.py new file mode 100644 index 000000000000..d869744590bd --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_table_commit.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import pyarrow.fs as pafs + +from pypaimon.table.format.format_table import FormatTable +from pypaimon.table.format.format_table_scan import _is_data_file_name +from pypaimon.table.format.format_commit_message import FormatTableCommitMessage + + +def _delete_data_files_in_path(file_io, path: str) -> None: + try: + infos = file_io.list_status(path) + except Exception: + return + for info in infos: + if info.type == pafs.FileType.Directory: + _delete_data_files_in_path(file_io, info.path) + elif info.type == pafs.FileType.File: + name = info.path.split("/")[-1] if "/" in info.path else info.path + if _is_data_file_name(name): + try: + file_io.delete(info.path, False) + except Exception: + pass + + +class FormatTableCommit: + """Commit for format table. Overwrite is applied in FormatTableWrite at write time.""" + + def __init__(self, table: FormatTable): + self.table = table + self._committed = False + + def commit(self, commit_messages: List[FormatTableCommitMessage]) -> None: + if self._committed: + raise RuntimeError("FormatTableCommit supports only one commit.") + self._committed = True + return + + def abort(self, commit_messages: List[FormatTableCommitMessage]) -> None: + for msg in commit_messages: + for path in msg.written_paths: + try: + if self.table.file_io.exists(path): + self.table.file_io.delete(path, False) + except Exception: + pass + + def close(self) -> None: + pass diff --git a/paimon-python/pypaimon/table/format/format_table_read.py b/paimon-python/pypaimon/table/format/format_table_read.py new file mode 100644 index 000000000000..11ce3faf5b29 --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_table_read.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, Iterator, List, Optional + +import pandas +import pyarrow + +from pypaimon.schema.data_types import PyarrowFieldParser +from pypaimon.table.format.format_data_split import FormatDataSplit +from pypaimon.table.format.format_table import FormatTable, Format + + +def _text_format_schema_column(table: FormatTable) -> Optional[str]: + """TEXT format: single data column name from schema (non-partition).""" + if table.format() != Format.TEXT or not table.fields: + return None + data_names = [ + f.name for f in table.fields if f.name not in table.partition_keys + ] + return ( + data_names[0] + if data_names + else (table.field_names[0] if table.field_names else None) + ) + + +def _read_file_to_arrow( + file_io: Any, + split: FormatDataSplit, + fmt: Format, + partition_spec: Optional[Dict[str, str]], + read_fields: Optional[List[str]], + partition_key_types: Optional[Dict[str, pyarrow.DataType]] = None, + text_column_name: Optional[str] = None, + text_line_delimiter: str = "\n", +) -> pyarrow.Table: + path = split.data_path() + csv_read_options = None + if fmt == Format.CSV and hasattr(pyarrow, "csv"): + csv_read_options = pyarrow.csv.ReadOptions(block_size=1 << 20) + try: + with file_io.new_input_stream(path) as stream: + chunks = [] + while True: + chunk = stream.read() + if not chunk: + break + chunks.append( + chunk if isinstance(chunk, bytes) else bytes(chunk) + ) + data = b"".join(chunks) + except Exception as e: + raise RuntimeError(f"Failed to read {path}") from e + + if not data or len(data) == 0: + return pyarrow.table({}) + + if fmt == Format.PARQUET: + import io + data = ( + bytes(data) if not isinstance(data, bytes) else data + ) + if len(data) < 4 or data[:4] != b"PAR1": + return pyarrow.table({}) + try: + tbl = pyarrow.parquet.read_table(io.BytesIO(data)) + except pyarrow.ArrowInvalid: + return pyarrow.table({}) + elif fmt == Format.CSV: + if hasattr(pyarrow, "csv"): + tbl = pyarrow.csv.read_csv( + pyarrow.BufferReader(data), + read_options=csv_read_options, + ) + else: + import io + df = pandas.read_csv(io.BytesIO(data)) + tbl = pyarrow.Table.from_pandas(df) + elif fmt == Format.JSON: + import json + text = data.decode("utf-8") if isinstance(data, bytes) else data + records = [] + for line in text.strip().split("\n"): + line = line.strip() + if line: + records.append(json.loads(line)) + if not records: + return pyarrow.table({}) + tbl = pyarrow.Table.from_pylist(records) + elif fmt == Format.ORC: + import io + data = bytes(data) if not isinstance(data, bytes) else data + if hasattr(pyarrow, "orc"): + try: + tbl = pyarrow.orc.read_table(io.BytesIO(data)) + except Exception: + return pyarrow.table({}) + else: + raise ValueError( + "Format table read for ORC requires PyArrow with ORC support " + "(pyarrow.orc)" + ) + elif fmt == Format.TEXT: + text = data.decode("utf-8") if isinstance(data, bytes) else data + lines = ( + text.rstrip(text_line_delimiter).split(text_line_delimiter) + if text + else [] + ) + if not lines: + return pyarrow.table({}) + part_keys = set(partition_spec.keys()) if partition_spec else set() + col_name = text_column_name if text_column_name else "value" + if read_fields: + for f in read_fields: + if f not in part_keys: + col_name = f + break + tbl = pyarrow.table({col_name: lines}) + else: + raise ValueError(f"Format {fmt} read not implemented in Python") + + if partition_spec: + for k, v in partition_spec.items(): + if k in tbl.column_names: + continue + pa_type = ( + partition_key_types.get(k, pyarrow.string()) + if 
partition_key_types + else pyarrow.string() + ) + arr = pyarrow.array([v] * tbl.num_rows, type=pyarrow.string()) + if pa_type != pyarrow.string(): + arr = arr.cast(pa_type) + tbl = tbl.append_column(k, arr) + + if read_fields and tbl.num_columns > 0: + existing = [c for c in read_fields if c in tbl.column_names] + if existing: + tbl = tbl.select(existing) + return tbl + + +def _partition_key_types( + table: FormatTable, +) -> Optional[Dict[str, pyarrow.DataType]]: + """Build partition column name -> PyArrow type from table schema.""" + if not table.partition_keys: + return None + result = {} + for f in table.fields: + if f.name in table.partition_keys: + pa_field = PyarrowFieldParser.from_paimon_field(f) + result[f.name] = pa_field.type + return result if result else None + + +class FormatTableRead: + + def __init__( + self, + table: FormatTable, + projection: Optional[List[str]] = None, + limit: Optional[int] = None, + ): + self.table = table + self.projection = projection + self.limit = limit + + def to_arrow( + self, + splits: List[FormatDataSplit], + ) -> pyarrow.Table: + read_fields = self.projection + fmt = self.table.format() + partition_key_types = _partition_key_types(self.table) + text_col = ( + _text_format_schema_column(self.table) + if fmt == Format.TEXT + else None + ) + text_delim = ( + self.table.options().get("text.line-delimiter", "\n") + if fmt == Format.TEXT + else "\n" + ) + tables = [] + nrows = 0 + for split in splits: + t = _read_file_to_arrow( + self.table.file_io, + split, + fmt, + split.partition, + read_fields, + partition_key_types, + text_column_name=text_col, + text_line_delimiter=text_delim, + ) + if t.num_rows > 0: + tables.append(t) + nrows += t.num_rows + if self.limit is not None and nrows >= self.limit: + if nrows > self.limit: + excess = nrows - self.limit + last = tables[-1] + tables[-1] = last.slice(0, last.num_rows - excess) + break + if not tables: + fields = self.table.fields + if read_fields: + fields = [ + f for f in self.table.fields if f.name in read_fields + ] + schema = PyarrowFieldParser.from_paimon_schema(fields) + return pyarrow.Table.from_pydict( + {n: [] for n in schema.names}, + schema=schema, + ) + out = pyarrow.concat_tables(tables) + if self.limit is not None and out.num_rows > self.limit: + out = out.slice(0, self.limit) + return out + + def to_pandas(self, splits: List[FormatDataSplit]) -> pandas.DataFrame: + return self.to_arrow(splits).to_pandas() + + def to_iterator( + self, + splits: List[FormatDataSplit], + ) -> Iterator[Any]: + partition_key_types = _partition_key_types(self.table) + fmt = self.table.format() + text_col = ( + _text_format_schema_column(self.table) + if fmt == Format.TEXT + else None + ) + text_delim = ( + self.table.options().get("text.line-delimiter", "\n") + if fmt == Format.TEXT + else "\n" + ) + n_yielded = 0 + for split in splits: + if self.limit is not None and n_yielded >= self.limit: + break + t = _read_file_to_arrow( + self.table.file_io, + split, + fmt, + split.partition, + self.projection, + partition_key_types, + text_column_name=text_col, + text_line_delimiter=text_delim, + ) + for batch in t.to_batches(): + for i in range(batch.num_rows): + if self.limit is not None and n_yielded >= self.limit: + return + yield batch.slice(i, 1) + n_yielded += 1 diff --git a/paimon-python/pypaimon/table/format/format_table_scan.py b/paimon-python/pypaimon/table/format/format_table_scan.py new file mode 100644 index 000000000000..9ce2d9ea26ad --- /dev/null +++ 
b/paimon-python/pypaimon/table/format/format_table_scan.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional + +import pyarrow.fs as pafs + +from pypaimon.common.file_io import FileIO +from pypaimon.read.plan import Plan +from pypaimon.table.format.format_data_split import FormatDataSplit +from pypaimon.table.format.format_table import FormatTable + + +def _is_data_file_name(name: str) -> bool: + if name is None: + return False + return not name.startswith(".") and not name.startswith("_") + + +def _is_reserved_dir_name(name: str) -> bool: + if not name: + return True + if name.startswith(".") or name.startswith("_"): + return True + if name.lower() in ("schema", "_schema"): + return True + return False + + +def _list_data_files_recursive( + file_io: FileIO, + path: str, + partition_keys: List[str], + partition_only_value: bool, + rel_path_parts: Optional[List[str]] = None, +) -> List[FormatDataSplit]: + splits: List[FormatDataSplit] = [] + rel_path_parts = rel_path_parts or [] + try: + infos = file_io.list_status(path) + except Exception: + return splits + if not infos: + return splits + path_rstrip = path.rstrip("/") + for info in infos: + name = info.path.split("/")[-1] if "/" in info.path else info.path + full_path = f"{path_rstrip}/{name}" if path_rstrip else name + if info.path.startswith("/") or info.path.startswith("file:"): + full_path = info.path + if info.type == pafs.FileType.Directory: + if _is_reserved_dir_name(name): + continue + part_value = name + if not partition_only_value and "=" in name: + part_value = name.split("=", 1)[1] + child_parts = rel_path_parts + [part_value] + if len(child_parts) <= len(partition_keys): + sub_splits = _list_data_files_recursive( + file_io, + full_path, + partition_keys, + partition_only_value, + child_parts, + ) + splits.extend(sub_splits) + elif info.type == pafs.FileType.File and _is_data_file_name(name): + size = getattr(info, "size", None) or 0 + part_spec: Optional[Dict[str, str]] = None + if partition_keys and len(rel_path_parts) >= len(partition_keys): + part_spec = dict( + zip( + partition_keys, + rel_path_parts[: len(partition_keys)], + ) + ) + splits.append( + FormatDataSplit( + file_path=full_path, + file_size=size, + partition=part_spec, + ) + ) + return splits + + +class FormatTableScan: + + def __init__( + self, + table: FormatTable, + partition_filter: Optional[Dict[str, str]] = None, + ): + self.table = table + self.partition_filter = partition_filter # optional equality filter + + def plan(self) -> Plan: + partition_only_value = self.table.options().get( + "format-table.partition-path-only-value", "false" + ).lower() == "true" + splits = _list_data_files_recursive( + self.table.file_io, + self.table.location(), + 
self.table.partition_keys, + partition_only_value, + ) + if self.partition_filter: + filtered = [] + for s in splits: + match = s.partition and all( + str(s.partition.get(k)) == str(v) + for k, v in self.partition_filter.items() + ) + if match: + filtered.append(s) + splits = filtered + return Plan(_splits=splits) diff --git a/paimon-python/pypaimon/table/format/format_table_write.py b/paimon-python/pypaimon/table/format/format_table_write.py new file mode 100644 index 000000000000..eb45b718d5c5 --- /dev/null +++ b/paimon-python/pypaimon/table/format/format_table_write.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import uuid +from collections import defaultdict +from typing import Dict, List, Optional + +import pyarrow + +from pypaimon.schema.data_types import PyarrowFieldParser +from pypaimon.table.format.format_commit_message import ( + FormatTableCommitMessage, +) +from pypaimon.table.format.format_table import ( + Format, + FormatTable, +) + + +def _partition_path( + partition_spec: dict, partition_keys: List[str], only_value: bool +) -> str: + parts = [] + for k in partition_keys: + v = partition_spec.get(k) + if v is None: + break + parts.append(str(v) if only_value else f"{k}={v}") + return "/".join(parts) + + +def _validate_partition_columns( + partition_keys: List[str], + data: pyarrow.RecordBatch, +) -> None: + """Raise if partition key missing from data (wrong column indexing).""" + names = set(data.schema.names) if data.schema else set() + missing = [k for k in partition_keys if k not in names] + if missing: + raise ValueError( + f"Partition column(s) missing from input data: {missing}. " + f"Data columns: {list(names)}. " + "Ensure partition keys exist in the Arrow schema." 
+ ) + + +def _partition_from_row( + row: pyarrow.RecordBatch, + partition_keys: List[str], + row_index: int, +) -> tuple: + out = [] + for k in partition_keys: + col = row.column(row.schema.get_field_index(k)) + val = col[row_index] + is_none = val is None or ( + hasattr(val, "as_py") and val.as_py() is None + ) + if is_none: + out.append(None) + else: + out.append(val.as_py() if hasattr(val, "as_py") else val) + return tuple(out) + + +class FormatTableWrite: + """Batch write for format table: Arrow/Pandas to partition dirs.""" + + def __init__( + self, + table: FormatTable, + overwrite: bool = False, + static_partitions: Optional[Dict[str, str]] = None, + ): + self.table = table + self._overwrite = overwrite + self._static_partitions = ( + static_partitions if static_partitions is not None else {} + ) + self._written_paths: List[str] = [] + self._overwritten_dirs: set = set() + opt = table.options().get( + "format-table.partition-path-only-value", "false" + ) + self._partition_only_value = opt.lower() == "true" + self._file_format = table.format() + self._data_file_prefix = "data-" + self._suffix = { + "parquet": ".parquet", + "csv": ".csv", + "json": ".json", + "orc": ".orc", + "text": ".txt", + }.get(self._file_format.value, ".parquet") + + def write_arrow(self, data: pyarrow.Table) -> None: + for batch in data.to_batches(): + self.write_arrow_batch(batch) + + def write_arrow_batch(self, data: pyarrow.RecordBatch) -> None: + partition_keys = self.table.partition_keys + if not partition_keys: + part_spec = {} + self._write_single_batch(data, part_spec) + return + _validate_partition_columns(partition_keys, data) + # Group rows by partition + parts_to_indices = defaultdict(list) + for i in range(data.num_rows): + part = _partition_from_row(data, partition_keys, i) + parts_to_indices[part].append(i) + for part_tuple, indices in parts_to_indices.items(): + part_spec = dict(zip(partition_keys, part_tuple)) + sub = data.take(pyarrow.array(indices)) + self._write_single_batch(sub, part_spec) + + def write_pandas(self, df) -> None: + pa_schema = PyarrowFieldParser.from_paimon_schema(self.table.fields) + batch = pyarrow.RecordBatch.from_pandas(df, schema=pa_schema) + self.write_arrow_batch(batch) + + def _write_single_batch( + self, + data: pyarrow.RecordBatch, + partition_spec: dict, + ) -> None: + if data.num_rows == 0: + return + location = self.table.location() + partition_only_value = self._partition_only_value + part_path = _partition_path( + partition_spec, + self.table.partition_keys, + partition_only_value, + ) + if part_path: + dir_path = f"{location}/{part_path}" + else: + dir_path = location + # When overwrite: clear partition dir only once per write session + overwrite_this = ( + self._overwrite + and dir_path not in self._overwritten_dirs + and self.table.file_io.exists(dir_path) + ) + if overwrite_this: + should_delete = ( + not self._static_partitions + or all( + str(partition_spec.get(k)) == str(v) + for k, v in self._static_partitions.items() + ) + ) + if should_delete: + from pypaimon.table.format.format_table_commit import ( + _delete_data_files_in_path, + ) + _delete_data_files_in_path(self.table.file_io, dir_path) + self._overwritten_dirs.add(dir_path) + self.table.file_io.check_or_mkdirs(dir_path) + file_name = f"{self._data_file_prefix}{uuid.uuid4().hex}{self._suffix}" + path = f"{dir_path}/{file_name}" + + fmt = self._file_format + tbl = pyarrow.Table.from_batches([data]) + if fmt == Format.PARQUET: + buf = io.BytesIO() + pyarrow.parquet.write_table(tbl, buf, 
compression="zstd") + raw = buf.getvalue() + elif fmt == Format.CSV: + if hasattr(pyarrow, "csv"): + buf = io.BytesIO() + pyarrow.csv.write_csv(tbl, buf) + raw = buf.getvalue() + else: + buf = io.StringIO() + tbl.to_pandas().to_csv(buf, index=False) + raw = buf.getvalue().encode("utf-8") + elif fmt == Format.JSON: + import json + lines = [] + for i in range(tbl.num_rows): + row = { + tbl.column_names[j]: tbl.column(j)[i].as_py() + for j in range(tbl.num_columns) + } + lines.append(json.dumps(row) + "\n") + raw = "".join(lines).encode("utf-8") + elif fmt == Format.ORC: + if hasattr(pyarrow, "orc"): + buf = io.BytesIO() + pyarrow.orc.write_table(tbl, buf) + raw = buf.getvalue() + else: + raise ValueError( + "Format table write for ORC requires PyArrow with ORC " + "support (pyarrow.orc)" + ) + elif fmt == Format.TEXT: + partition_keys = self.table.partition_keys + if partition_keys: + data_cols = [ + c for c in tbl.column_names if c not in partition_keys + ] + tbl = tbl.select(data_cols) + pa_f0 = tbl.schema.field(0).type + if tbl.num_columns != 1 or not pyarrow.types.is_string(pa_f0): + raise ValueError( + "TEXT format only supports a single string column, " + f"got {tbl.num_columns} columns" + ) + line_delimiter = self.table.options().get( + "text.line-delimiter", "\n" + ) + lines = [] + col = tbl.column(0) + for i in range(tbl.num_rows): + val = col[i] + py_val = val.as_py() if hasattr(val, "as_py") else val + line = "" if py_val is None else str(py_val) + lines.append(line + line_delimiter) + raw = "".join(lines).encode("utf-8") + else: + raise ValueError(f"Format table write not implemented for {fmt}") + + with self.table.file_io.new_output_stream(path) as out: + out.write(raw) + + self._written_paths.append(path) + + def prepare_commit(self) -> List[FormatTableCommitMessage]: + return [ + FormatTableCommitMessage( + written_paths=list(self._written_paths) + ) + ] + + def close(self) -> None: + pass diff --git a/paimon-python/pypaimon/tests/rest/rest_format_table_test.py b/paimon-python/pypaimon/tests/rest/rest_format_table_test.py new file mode 100644 index 000000000000..d081ae6cd044 --- /dev/null +++ b/paimon-python/pypaimon/tests/rest/rest_format_table_test.py @@ -0,0 +1,613 @@ +""" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest + +import pandas as pd +import pyarrow as pa +from parameterized import parameterized + +from pypaimon import Schema +from pypaimon.catalog.catalog_exception import TableNotExistException +from pypaimon.table.format import FormatTable +from pypaimon.tests.rest.rest_base_test import RESTBaseTest + + +def _format_table_read_write_formats(): + formats = [("parquet",), ("csv",), ("json",)] + if hasattr(pa, "orc"): + formats.append(("orc",)) + return formats + + +class RESTFormatTableTest(RESTBaseTest): + + @parameterized.expand(_format_table_read_write_formats()) + def test_format_table_read_write(self, file_format): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ("c", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": file_format}, + ) + table_name = f"default.format_table_rw_{file_format}" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + self.assertIsInstance(table, FormatTable) + self.assertEqual(table.format().value, file_format) + opts = table.options() + self.assertIsInstance(opts, dict) + self.assertEqual(opts.get("file.format"), file_format) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + df = pd.DataFrame({ + "a": [10, 10], + "b": [1, 2], + "c": [1, 2], + }) + table_write.write_pandas(df) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + table_read = read_builder.new_read() + actual = table_read.to_pandas(splits).sort_values(by="b").reset_index(drop=True) + expected = pa.Table.from_pydict( + {"a": [10, 10], "b": [1, 2], "c": [1, 2]}, + schema=pa_schema, + ).to_pandas() + for col in expected.columns: + if col in actual.columns and actual[col].dtype != expected[col].dtype: + actual[col] = actual[col].astype(expected[col].dtype) + pd.testing.assert_frame_equal(actual, expected) + + def test_format_table_text_read_write(self): + pa_schema = pa.schema([("value", pa.string())]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": "text"}, + ) + table_name = "default.format_table_rw_text" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + self.assertIsInstance(table, FormatTable) + self.assertEqual(table.format().value, "text") + opts = table.options() + self.assertIsInstance(opts, dict) + self.assertEqual(opts.get("file.format"), "text") + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + df = pd.DataFrame({"value": ["hello", "world"]}) + table_write.write_pandas(df) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + table_read = read_builder.new_read() + actual = table_read.to_pandas(splits).sort_values(by="value").reset_index(drop=True) + expected = pd.DataFrame({"value": ["hello", "world"]}) + pd.testing.assert_frame_equal(actual, expected) + + def 
test_format_table_text_read_write_with_nulls(self): + pa_schema = pa.schema([("value", pa.string())]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": "text"}, + ) + table_name = "default.format_table_rw_text_nulls" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + df = pd.DataFrame({"value": ["hello", None, "world"]}) + table_write.write_pandas(df) + table_commit.commit(table_write.prepare_commit()) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + table_read = read_builder.new_read() + actual = table_read.to_pandas(splits) + self.assertEqual(actual.shape[0], 3) + # Nulls are written as empty string; read back as "" + self.assertEqual(set(actual["value"].fillna("").astype(str)), {"", "hello", "world"}) + self.assertIn("", actual["value"].values) + + def test_format_table_text_partitioned_read_write(self): + pa_schema = pa.schema([ + ("value", pa.string()), + ("dt", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=["dt"], + options={"type": "format-table", "file.format": "text"}, + ) + table_name = "default.format_table_rw_text_partitioned" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + self.assertIsInstance(table, FormatTable) + self.assertEqual(table.format().value, "text") + + write_builder = table.new_batch_write_builder() + tw = write_builder.new_write() + tc = write_builder.new_commit() + tw.write_pandas(pd.DataFrame({"value": ["a", "b"], "dt": [1, 1]})) + tw.write_pandas(pd.DataFrame({"value": ["c"], "dt": [2]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + actual = read_builder.new_read().to_pandas(splits).sort_values(by=["dt", "value"]).reset_index(drop=True) + self.assertEqual(actual.shape[0], 3) + self.assertEqual(actual["value"].tolist(), ["a", "b", "c"]) + self.assertEqual(actual["dt"].tolist(), [1, 1, 2]) + + def test_format_table_read_with_limit_to_iterator(self): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": "parquet"}, + ) + table_name = "default.format_table_limit_iterator" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + write_builder = table.new_batch_write_builder() + tw = write_builder.new_write() + tc = write_builder.new_commit() + tw.write_pandas(pd.DataFrame({"a": [1, 2, 3, 4], "b": [10, 20, 30, 40]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + splits = table.new_read_builder().new_scan().plan().splits() + limit = 2 + read_builder = table.new_read_builder().with_limit(limit) + table_read = read_builder.new_read() + + df = table_read.to_pandas(splits) + self.assertEqual(len(df), limit, "to_pandas must respect with_limit(2)") + + batches = 
list(table_read.to_iterator(splits)) + self.assertEqual(len(batches), limit, "to_iterator must respect with_limit(2)") + + @parameterized.expand(_format_table_read_write_formats()) + def test_format_table_partitioned_overwrite(self, file_format): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ("c", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=["c"], + options={"type": "format-table", "file.format": file_format}, + ) + table_name = f"default.format_table_partitioned_overwrite_{file_format}" + self.rest_catalog.drop_table(table_name, True) + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + + write_builder = table.new_batch_write_builder() + tw = write_builder.new_write() + tc = write_builder.new_commit() + tw.write_pandas(pd.DataFrame({"a": [10, 10], "b": [10, 20], "c": [1, 1]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + tw = table.new_batch_write_builder().overwrite({"c": 1}).new_write() + tc = table.new_batch_write_builder().overwrite({"c": 1}).new_commit() + tw.write_pandas(pd.DataFrame({"a": [12, 12], "b": [100, 200], "c": [1, 1]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + read_builder = table.new_read_builder() + splits = read_builder.new_scan().plan().splits() + actual = read_builder.new_read().to_pandas(splits).sort_values(by="b") + self.assertEqual(len(actual), 2) + self.assertEqual(actual["b"].tolist(), [100, 200]) + + def test_format_table_overwrite_only_specified_partition(self): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ("c", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=["c"], + options={"type": "format-table", "file.format": "parquet"}, + ) + table_name = "default.format_table_overwrite_one_partition" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [10, 10], "b": [10, 20], "c": [1, 1]})) + tw.write_pandas(pd.DataFrame({"a": [30, 30], "b": [30, 40], "c": [2, 2]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + tw = table.new_batch_write_builder().overwrite({"c": 1}).new_write() + tc = table.new_batch_write_builder().overwrite({"c": 1}).new_commit() + tw.write_pandas(pd.DataFrame({"a": [12, 12], "b": [100, 200], "c": [1, 1]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + actual = table.new_read_builder().new_read().to_pandas( + table.new_read_builder().new_scan().plan().splits() + ).sort_values(by=["c", "b"]) + self.assertEqual(len(actual), 4) + self.assertEqual(actual["b"].tolist(), [100, 200, 30, 40]) + self.assertEqual(actual["c"].tolist(), [1, 1, 2, 2]) + c1 = actual[actual["c"] == 1]["b"].tolist() + c2 = actual[actual["c"] == 2]["b"].tolist() + self.assertEqual(c1, [100, 200], "partition c=1 must be overwritten") + self.assertEqual(c2, [30, 40], "partition c=2 must be unchanged") + + def test_format_table_overwrite_multiple_batches_same_partition(self): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": "parquet"}, + ) + table_name = "default.format_table_overwrite_multi_batch" + try: + 
self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [1, 2], "b": [10, 20]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + tw = wb.overwrite().new_write() + tc = wb.overwrite().new_commit() + tw.write_pandas(pd.DataFrame({"a": [3, 4], "b": [30, 40]})) + tw.write_pandas(pd.DataFrame({"a": [5, 6], "b": [50, 60]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + actual = table.new_read_builder().new_read().to_pandas( + table.new_read_builder().new_scan().plan().splits() + ).sort_values(by="b") + self.assertEqual(len(actual), 4, "overwrite + 2 write_pandas same partition must keep all 4 rows") + self.assertEqual(actual["b"].tolist(), [30, 40, 50, 60]) + + @parameterized.expand(_format_table_read_write_formats()) + def test_format_table_partitioned_read_write(self, file_format): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ("dt", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=["dt"], + options={"type": "format-table", "file.format": file_format}, + ) + table_name = f"default.format_table_partitioned_rw_{file_format}" + self.rest_catalog.drop_table(table_name, True) + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + self.assertIsInstance(table, FormatTable) + + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [1, 2], "b": [10, 20], "dt": [10, 10]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [3, 4], "b": [30, 40], "dt": [11, 11]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + splits_all = rb.new_scan().plan().splits() + actual_all = rb.new_read().to_pandas(splits_all).sort_values(by="b") + self.assertEqual(len(actual_all), 4) + self.assertEqual(sorted(actual_all["b"].tolist()), [10, 20, 30, 40]) + + rb_dt10 = table.new_read_builder().with_partition_filter({"dt": "10"}) + splits_dt10 = rb_dt10.new_scan().plan().splits() + actual_dt10 = rb_dt10.new_read().to_pandas(splits_dt10).sort_values(by="b") + self.assertEqual(len(actual_dt10), 2) + self.assertEqual(actual_dt10["b"].tolist(), [10, 20]) + # Partition column must match schema type (int32), not string + self.assertEqual(actual_dt10["dt"].tolist(), [10, 10]) + self.assertEqual(actual_all["dt"].tolist(), [10, 10, 11, 11]) + + def test_format_table_partition_column_returns_schema_type(self): + """Partition columns must be returned with schema type (e.g. 
int32), not always string.""" + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ("dt", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=["dt"], + options={"type": "format-table", "file.format": "parquet"}, + ) + table_name = "default.format_table_partition_schema_type" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [1, 2], "b": [10, 20], "dt": [1, 1]})) + tw.write_pandas(pd.DataFrame({"a": [3, 4], "b": [30, 40], "dt": [2, 2]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder().with_partition_filter({"dt": "1"}) + splits = rb.new_scan().plan().splits() + actual = rb.new_read().to_pandas(splits).sort_values(by="b") + self.assertEqual(len(actual), 2) + self.assertEqual(actual["b"].tolist(), [10, 20]) + # Must be int list, not string list; fails if partition column is hardcoded as string + self.assertEqual(actual["dt"].tolist(), [1, 1]) + self.assertTrue( + actual["dt"].dtype in (pd.Int32Dtype(), "int32", "int64"), + "dt must be int type per schema, not string", + ) + + def test_format_table_with_filter_extracts_partition_like_java(self): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ("dt", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=["dt"], + options={"type": "format-table", "file.format": "parquet"}, + ) + table_name = "default.format_table_with_filter_assert" + try: + self.rest_catalog.drop_table(table_name, True) + except Exception: + pass + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [1, 2], "b": [10, 20], "dt": [10, 10]})) + tw.write_pandas(pd.DataFrame({"a": [3, 4], "b": [30, 40], "dt": [11, 11]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + predicate_eq_dt10 = table.new_read_builder().new_predicate_builder().equal("dt", 10) + splits_by_partition_filter = ( + table.new_read_builder().with_partition_filter({"dt": "10"}).new_scan().plan().splits() + ) + splits_by_with_filter = ( + table.new_read_builder().with_filter(predicate_eq_dt10).new_scan().plan().splits() + ) + self.assertEqual( + len(splits_by_with_filter), len(splits_by_partition_filter), + "with_filter(partition equality) must behave like with_partition_filter (Java-aligned)", + ) + actual_from_filter = ( + table.new_read_builder().with_filter(predicate_eq_dt10).new_read().to_pandas(splits_by_with_filter) + ) + self.assertEqual(len(actual_from_filter), 2) + self.assertEqual(actual_from_filter["b"].tolist(), [10, 20]) + + splits_partition_then_filter = ( + table.new_read_builder() + .with_partition_filter({"dt": "10"}) + .with_filter(predicate_eq_dt10) + .new_scan() + .plan() + .splits() + ) + self.assertEqual( + len(splits_partition_then_filter), len(splits_by_partition_filter), + "with_filter must not overwrite a previously set partition filter", + ) + actual = ( + table.new_read_builder() + .with_partition_filter({"dt": "10"}) + .with_filter(predicate_eq_dt10) + .new_read() + .to_pandas(splits_partition_then_filter) + ) + self.assertEqual(len(actual), 2) + self.assertEqual(actual["b"].tolist(), [10, 
20]) + + predicate_non_partition = table.new_read_builder().new_predicate_builder().equal("a", 1) + splits_no_filter = table.new_read_builder().new_scan().plan().splits() + splits_with_non_partition_predicate = ( + table.new_read_builder().with_filter(predicate_non_partition).new_scan().plan().splits() + ) + self.assertEqual( + len(splits_with_non_partition_predicate), len(splits_no_filter), + "with_filter(non-partition predicate) must not change scan when no partition spec extracted", + ) + + @parameterized.expand(_format_table_read_write_formats()) + def test_format_table_full_overwrite(self, file_format): + pa_schema = pa.schema([ + ("a", pa.int32()), + ("b", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": file_format}, + ) + table_name = f"default.format_table_full_overwrite_{file_format}" + self.rest_catalog.drop_table(table_name, True) + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(pd.DataFrame({"a": [1, 2], "b": [10, 20]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + tw = wb.overwrite().new_write() + tc = wb.overwrite().new_commit() + tw.write_pandas(pd.DataFrame({"a": [3], "b": [30]})) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + actual = rb.new_read().to_pandas(splits) + self.assertEqual(len(actual), 1) + self.assertEqual(actual["b"].tolist(), [30]) + + @parameterized.expand(_format_table_read_write_formats()) + def test_format_table_split_read(self, file_format): + pa_schema = pa.schema([ + ("id", pa.int32()), + ("name", pa.string()), + ("score", pa.float64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={ + "type": "format-table", + "file.format": file_format, + "source.split.target-size": "54", + }, + ) + table_name = f"default.format_table_split_read_{file_format}" + self.rest_catalog.drop_table(table_name, True) + self.rest_catalog.create_table(table_name, schema, False) + table = self.rest_catalog.get_table(table_name) + + size = 50 + for i in range(0, size, 10): + batch = pd.DataFrame({ + "id": list(range(i, min(i + 10, size))), + "name": [f"User{j}" for j in range(i, min(i + 10, size))], + "score": [85.5 + (j % 15) for j in range(i, min(i + 10, size))], + }) + wb = table.new_batch_write_builder() + tw = wb.new_write() + tc = wb.new_commit() + tw.write_pandas(batch) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + rb = table.new_read_builder() + splits = rb.new_scan().plan().splits() + actual = rb.new_read().to_pandas(splits).sort_values(by="id") + self.assertEqual(len(actual), size) + self.assertEqual(actual["id"].tolist(), list(range(size))) + + @parameterized.expand(_format_table_read_write_formats()) + def test_format_table_catalog(self, file_format): + pa_schema = pa.schema([ + ("str", pa.string()), + ("int", pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, + options={"type": "format-table", "file.format": file_format}, + ) + table_name = f"default.format_table_catalog_{file_format}" + self.rest_catalog.drop_table(table_name, True) + self.rest_catalog.create_table(table_name, schema, False) + self.assertIn(f"format_table_catalog_{file_format}", self.rest_catalog.list_tables("default")) + table = self.rest_catalog.get_table(table_name) + self.assertIsInstance(table, 
FormatTable) + + self.rest_catalog.drop_table(table_name, False) + with self.assertRaises(TableNotExistException): + self.rest_catalog.get_table(table_name) + + +if __name__ == "__main__": + unittest.main()
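
For reviewers, a minimal end-to-end sketch of the API added by this change. This is not part of the patch; it assumes an already-initialized RESTCatalog named `catalog` and an existing `default` database, and the table name is illustrative. All calls shown (create_table, get_table, new_batch_write_builder, write_pandas, prepare_commit/commit, new_read_builder, with_partition_filter, new_scan().plan().splits(), to_pandas) are taken from the code and tests in this diff.

import pandas as pd
import pyarrow as pa

from pypaimon import Schema
from pypaimon.table.format import FormatTable

# Declare a partitioned format table via options (type=format-table).
pa_schema = pa.schema([("a", pa.int32()), ("dt", pa.int32())])
schema = Schema.from_pyarrow_schema(
    pa_schema,
    partition_keys=["dt"],
    options={"type": "format-table", "file.format": "csv"},
)
catalog.create_table("default.fmt_demo", schema, False)  # `catalog` is assumed

# get_table returns a FormatTable when the schema options carry type=format-table.
table = catalog.get_table("default.fmt_demo")
assert isinstance(table, FormatTable)

# Batch write: rows are grouped by partition and written under dt=<value>/data-*.csv.
wb = table.new_batch_write_builder()
tw = wb.new_write()
tc = wb.new_commit()
tw.write_pandas(pd.DataFrame({"a": [1, 2], "dt": [10, 10]}))
tc.commit(tw.prepare_commit())
tw.close()
tc.close()

# Read back only partition dt=10; with_filter on a partition-equality predicate
# prunes to the same splits as with_partition_filter.
rb = table.new_read_builder().with_partition_filter({"dt": "10"})
df = rb.new_read().to_pandas(rb.new_scan().plan().splits())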