From 18c27f28df9847e75d3ef50aacb61a27101c64aa Mon Sep 17 00:00:00 2001 From: geruh Date: Wed, 24 Dec 2025 20:04:41 -0800 Subject: [PATCH] feat: Add support for rest scan planning --- pyiceberg/catalog/rest/__init__.py | 119 ++++- pyiceberg/catalog/rest/scan_planning.py | 9 +- pyiceberg/exceptions.py | 4 + pyiceberg/table/__init__.py | 128 +++++- tests/catalog/test_scan_planning_models.py | 422 +++++++++++------- .../test_rest_scan_planning_integration.py | 346 ++++++++++++++ 6 files changed, 870 insertions(+), 158 deletions(-) create mode 100644 tests/integration/test_rest_scan_planning_integration.py diff --git a/pyiceberg/catalog/rest/__init__.py b/pyiceberg/catalog/rest/__init__.py index 7866f5e8bd..a6993ac078 100644 --- a/pyiceberg/catalog/rest/__init__.py +++ b/pyiceberg/catalog/rest/__init__.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from collections import deque from enum import Enum from typing import ( TYPE_CHECKING, @@ -21,7 +22,7 @@ Union, ) -from pydantic import ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, TypeAdapter, field_validator from requests import HTTPError, Session from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt @@ -36,6 +37,16 @@ ) from pyiceberg.catalog.rest.auth import AuthManager, AuthManagerAdapter, AuthManagerFactory, LegacyOAuth2AuthManager from pyiceberg.catalog.rest.response import _handle_non_200_response +from pyiceberg.catalog.rest.scan_planning import ( + FetchScanTasksRequest, + PlanCancelled, + PlanCompleted, + PlanFailed, + PlanningResponse, + PlanSubmitted, + PlanTableScanRequest, + ScanTasks, +) from pyiceberg.exceptions import ( AuthorizationExpiredError, CommitFailedException, @@ -44,6 +55,7 @@ NamespaceNotEmptyError, NoSuchIdentifierError, NoSuchNamespaceError, + NoSuchPlanTaskError, NoSuchTableError, NoSuchViewError, TableAlreadyExistsError, @@ -56,6 +68,7 @@ CommitTableRequest, CommitTableResponse, CreateTableTransaction, + FileScanTask, StagedTable, Table, TableIdentifier, @@ -315,6 +328,9 @@ class ListViewsResponse(IcebergBaseModel): identifiers: list[ListViewResponseEntry] = Field() +_PLANNING_RESPONSE_ADAPTER = TypeAdapter(PlanningResponse) + + class RestCatalog(Catalog): uri: str _session: Session @@ -384,6 +400,107 @@ def is_rest_scan_planning_enabled(self) -> bool: self.properties, REST_SCAN_PLANNING_ENABLED, REST_SCAN_PLANNING_ENABLED_DEFAULT ) + @retry(**_RETRY_ARGS) + def _plan_table_scan(self, identifier: str | Identifier, request: PlanTableScanRequest) -> PlanningResponse: + """Submit a scan plan request to the REST server. + + Args: + identifier: Table identifier. + request: The scan plan request parameters. + + Returns: + PlanningResponse the result of the scan plan request representing the status + Raises: + NoSuchTableError: If a table with the given identifier does not exist. 
+ """ + self._check_endpoint(Capability.V1_SUBMIT_TABLE_SCAN_PLAN) + response = self._session.post( + self.url(Endpoints.plan_table_scan, prefixed=True, **self._split_identifier_for_path(identifier)), + data=request.model_dump_json(by_alias=True, exclude_none=True).encode(UTF8), + ) + try: + response.raise_for_status() + except HTTPError as exc: + _handle_non_200_response(exc, {404: NoSuchTableError}) + + return _PLANNING_RESPONSE_ADAPTER.validate_json(response.text) + + @retry(**_RETRY_ARGS) + def _fetch_scan_tasks(self, identifier: str | Identifier, plan_task: str) -> ScanTasks: + """Fetch additional scan tasks using a plan task token. + + Args: + identifier: Table identifier. + plan_task: The plan task token from a previous response. + + Returns: + ScanTasks containing file scan tasks and possibly more plan-task tokens. + + Raises: + NoSuchPlanTaskError: If a plan task with the given identifier or task does not exist. + """ + self._check_endpoint(Capability.V1_TABLE_SCAN_PLAN_TASKS) + request = FetchScanTasksRequest(plan_task=plan_task) + response = self._session.post( + self.url(Endpoints.fetch_scan_tasks, prefixed=True, **self._split_identifier_for_path(identifier)), + data=request.model_dump_json(by_alias=True).encode(UTF8), + ) + try: + response.raise_for_status() + except HTTPError as exc: + _handle_non_200_response(exc, {404: NoSuchPlanTaskError}) + + return ScanTasks.model_validate_json(response.text) + + def plan_scan(self, identifier: str | Identifier, request: PlanTableScanRequest) -> list[FileScanTask]: + """Plan a table scan and return FileScanTasks. + + Handles the full scan planning lifecycle including pagination. + + Args: + identifier: Table identifier. + request: The scan plan request parameters. + + Returns: + List of FileScanTask objects ready for execution. + + Raises: + RuntimeError: If planning fails, is cancelled, or returns unexpected response. + NotImplementedError: If async planning is required but not yet supported. + """ + response = self._plan_table_scan(identifier, request) + + if isinstance(response, PlanFailed): + error_msg = response.error.message if response.error else "unknown error" + raise RuntimeError(f"Received status: failed: {error_msg}") + + if isinstance(response, PlanCancelled): + raise RuntimeError("Received status: cancelled") + + if isinstance(response, PlanSubmitted): + # TODO: implement polling for async planning + raise NotImplementedError(f"Async scan planning not yet supported for planId: {response.plan_id}") + + if not isinstance(response, PlanCompleted): + raise RuntimeError(f"Invalid planStatus for response: {type(response).__name__}") + + tasks: list[FileScanTask] = [] + + # Collect tasks from initial response + for task in response.file_scan_tasks: + tasks.append(FileScanTask.from_rest_response(task, response.delete_files)) + + # Fetch and collect from additional batches + pending_tasks = deque(response.plan_tasks) + while pending_tasks: + plan_task = pending_tasks.popleft() + batch = self._fetch_scan_tasks(identifier, plan_task) + for task in batch.file_scan_tasks: + tasks.append(FileScanTask.from_rest_response(task, batch.delete_files)) + pending_tasks.extend(batch.plan_tasks) + + return tasks + def _create_legacy_oauth2_auth_manager(self, session: Session) -> AuthManager: """Create the LegacyOAuth2AuthManager by fetching required properties. 
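
Note (not part of the diff): a minimal sketch of how the new catalog-level planning API above is meant to be driven, assuming a reachable REST catalog at a placeholder URI and a placeholder ("db", "tbl") table. plan_scan() submits the plan, rejects failed/cancelled responses, raises NotImplementedError for asynchronous "submitted" plans, and drains any plan-task tokens through _fetch_scan_tasks before returning FileScanTasks.

    from pyiceberg.catalog.rest import RestCatalog
    from pyiceberg.catalog.rest.scan_planning import PlanTableScanRequest
    from pyiceberg.expressions import EqualTo

    # Placeholder URI; any catalog advertising the plan/tasks endpoints in /v1/config works.
    catalog = RestCatalog(
        "rest",
        uri="http://localhost:8181",
        **{"rest-scan-planning-enabled": "true"},
    )

    # Unset fields are dropped from the request body (model_dump_json(..., exclude_none=True)).
    request = PlanTableScanRequest(
        filter=EqualTo("id", 42),  # pushed down to the server as the scan filter
        case_sensitive=True,
    )

    tasks = catalog.plan_scan(("db", "tbl"), request)
    for task in tasks:
        print(task.file.file_path, len(task.delete_files))
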
diff --git a/pyiceberg/catalog/rest/scan_planning.py b/pyiceberg/catalog/rest/scan_planning.py index 1a2204341b..c9131fbb2b 100644 --- a/pyiceberg/catalog/rest/scan_planning.py +++ b/pyiceberg/catalog/rest/scan_planning.py @@ -25,9 +25,16 @@ from pyiceberg.catalog.rest.response import ErrorResponseMessage from pyiceberg.expressions import BooleanExpression, SerializableBooleanExpression -from pyiceberg.manifest import FileFormat +from pyiceberg.manifest import DataFileContent, FileFormat from pyiceberg.typedef import IcebergBaseModel +# REST content-type to DataFileContent +CONTENT_TYPE_MAP: dict[str, DataFileContent] = { + "data": DataFileContent.DATA, + "position-deletes": DataFileContent.POSITION_DELETES, + "equality-deletes": DataFileContent.EQUALITY_DELETES, +} + # Primitive types that can appear in partition values and bounds PrimitiveTypeValue: TypeAlias = bool | int | float | str | Decimal | UUID | date | time | datetime | bytes diff --git a/pyiceberg/exceptions.py b/pyiceberg/exceptions.py index c80f104e46..e755c73095 100644 --- a/pyiceberg/exceptions.py +++ b/pyiceberg/exceptions.py @@ -52,6 +52,10 @@ class NoSuchNamespaceError(Exception): """Raised when a referenced name-space is not found.""" +class NoSuchPlanTaskError(Exception): + """Raised when a scan plan task is not found.""" + + class RESTError(Exception): """Raises when there is an unknown response from the REST Catalog.""" diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2e26a4ccc2..285b10435a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -145,6 +145,11 @@ from pyiceberg_core.datafusion import IcebergDataFusionTable from pyiceberg.catalog import Catalog + from pyiceberg.catalog.rest.scan_planning import ( + RESTContentFile, + RESTDeleteFile, + RESTFileScanTask, + ) ALWAYS_TRUE = AlwaysTrue() DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write" @@ -1168,6 +1173,8 @@ def scan( snapshot_id=snapshot_id, options=options, limit=limit, + catalog=self.catalog, + table_identifier=self._identifier, ) @property @@ -1684,6 +1691,8 @@ class TableScan(ABC): snapshot_id: int | None options: Properties limit: int | None + catalog: Catalog | None + table_identifier: Identifier | None def __init__( self, @@ -1695,6 +1704,8 @@ def __init__( snapshot_id: int | None = None, options: Properties = EMPTY_DICT, limit: int | None = None, + catalog: Catalog | None = None, + table_identifier: Identifier | None = None, ): self.table_metadata = table_metadata self.io = io @@ -1704,6 +1715,8 @@ def __init__( self.snapshot_id = snapshot_id self.options = options self.limit = limit + self.catalog = catalog + self.table_identifier = table_identifier def snapshot(self) -> Snapshot | None: if self.snapshot_id: @@ -1798,6 +1811,74 @@ def __init__( self.delete_files = delete_files or set() self.residual = residual + @staticmethod + def from_rest_response( + rest_task: RESTFileScanTask, + delete_files: list[RESTDeleteFile], + ) -> FileScanTask: + """Convert a RESTFileScanTask to a FileScanTask. + + Args: + rest_task: The REST file scan task. + delete_files: The list of delete files from the ScanTasks response. + + Returns: + A FileScanTask with the converted data and delete files. + + Raises: + NotImplementedError: If equality delete files are encountered. 
+ """ + from pyiceberg.catalog.rest.scan_planning import RESTEqualityDeleteFile + + data_file = _rest_file_to_data_file(rest_task.data_file) + + resolved_deletes: set[DataFile] = set() + if rest_task.delete_file_references: + for idx in rest_task.delete_file_references: + delete_file = delete_files[idx] + if isinstance(delete_file, RESTEqualityDeleteFile): + raise NotImplementedError(f"PyIceberg does not yet support equality deletes: {delete_file.file_path}") + resolved_deletes.add(_rest_file_to_data_file(delete_file)) + + return FileScanTask( + data_file=data_file, + delete_files=resolved_deletes, + residual=rest_task.residual_filter if rest_task.residual_filter else ALWAYS_TRUE, + ) + + +def _rest_file_to_data_file(rest_file: RESTContentFile) -> DataFile: + """Convert a REST content file to a manifest DataFile.""" + from pyiceberg.catalog.rest.scan_planning import CONTENT_TYPE_MAP, RESTDataFile + + if isinstance(rest_file, RESTDataFile): + column_sizes = rest_file.column_sizes.to_dict() if rest_file.column_sizes else None + value_counts = rest_file.value_counts.to_dict() if rest_file.value_counts else None + null_value_counts = rest_file.null_value_counts.to_dict() if rest_file.null_value_counts else None + nan_value_counts = rest_file.nan_value_counts.to_dict() if rest_file.nan_value_counts else None + else: + column_sizes = None + value_counts = None + null_value_counts = None + nan_value_counts = None + + data_file = DataFile.from_args( + content=CONTENT_TYPE_MAP[rest_file.content], + file_path=rest_file.file_path, + file_format=rest_file.file_format, + partition=Record(*rest_file.partition) if rest_file.partition else Record(), + record_count=rest_file.record_count, + file_size_in_bytes=rest_file.file_size_in_bytes, + column_sizes=column_sizes, + value_counts=value_counts, + null_value_counts=null_value_counts, + nan_value_counts=nan_value_counts, + split_offsets=rest_file.split_offsets, + sort_order_id=rest_file.sort_order_id, + ) + data_file.spec_id = rest_file.spec_id + return data_file + def _open_manifest( io: FileIO, @@ -1970,12 +2051,35 @@ def scan_plan_helper(self) -> Iterator[list[ManifestEntry]]: ], ) - def plan_files(self) -> Iterable[FileScanTask]: - """Plans the relevant files by filtering on the PartitionSpecs. + def _should_use_rest_planning(self) -> bool: + """Check if REST scan planning should be used for this scan.""" + from pyiceberg.catalog.rest import RestCatalog + + if not isinstance(self.catalog, RestCatalog): + return False + return self.catalog.is_rest_scan_planning_enabled() + + def _plan_files_rest(self) -> Iterable[FileScanTask]: + """Plan files using REST server-side scan planning.""" + from pyiceberg.catalog.rest import RestCatalog + from pyiceberg.catalog.rest.scan_planning import PlanTableScanRequest + + if not isinstance(self.catalog, RestCatalog): + raise TypeError("REST scan planning requires a RestCatalog") + if self.table_identifier is None: + raise ValueError("REST scan planning requires a table identifier") + + request = PlanTableScanRequest( + snapshot_id=self.snapshot_id, + select=list(self.selected_fields) if self.selected_fields != ("*",) else None, + filter=self.row_filter if self.row_filter != ALWAYS_TRUE else None, + case_sensitive=self.case_sensitive, + ) - Returns: - List of FileScanTasks that contain both data and delete files. 
- """ + return self.catalog.plan_scan(self.table_identifier, request) + + def _plan_files_local(self) -> Iterable[FileScanTask]: + """Plan files locally by reading manifests.""" data_entries: list[ManifestEntry] = [] positional_delete_entries = SortedList(key=lambda entry: entry.sequence_number or INITIAL_SEQUENCE_NUMBER) @@ -2006,6 +2110,20 @@ def plan_files(self) -> Iterable[FileScanTask]: for data_entry in data_entries ] + def plan_files(self) -> Iterable[FileScanTask]: + """Plans the relevant files by filtering on the PartitionSpecs. + + If the table comes from a REST catalog with scan planning enabled, + this will use server-side scan planning. Otherwise, it falls back + to local planning. + + Returns: + List of FileScanTasks that contain both data and delete files. + """ + if self._should_use_rest_planning(): + return self._plan_files_rest() + return self._plan_files_local() + def to_arrow(self) -> pa.Table: """Read an Arrow table eagerly from this DataScan. diff --git a/tests/catalog/test_scan_planning_models.py b/tests/catalog/test_scan_planning_models.py index 9f03c8f7cd..567f1444a7 100644 --- a/tests/catalog/test_scan_planning_models.py +++ b/tests/catalog/test_scan_planning_models.py @@ -18,7 +18,9 @@ import pytest from pydantic import TypeAdapter, ValidationError +from requests_mock import Mocker +from pyiceberg.catalog.rest import RestCatalog from pyiceberg.catalog.rest.scan_planning import ( CountMap, FetchScanTasksRequest, @@ -33,12 +35,92 @@ RESTFileScanTask, RESTPositionDeleteFile, ScanTasks, - StorageCredential, ValueMap, ) from pyiceberg.expressions import AlwaysTrue, EqualTo, Reference from pyiceberg.manifest import FileFormat +TEST_URI = "https://iceberg-test-catalog/" + + +@pytest.fixture +def rest_scan_catalog(requests_mock: Mocker) -> RestCatalog: + requests_mock.get( + f"{TEST_URI}v1/config", + json={ + "defaults": {"rest-scan-planning-enabled": "true"}, + "overrides": {}, + "endpoints": [ + "POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/plan", + "POST /v1/{prefix}/namespaces/{namespace}/tables/{table}/tasks", + ], + }, + status_code=200, + ) + + return RestCatalog( + "test", + uri=TEST_URI, + **{"rest-scan-planning-enabled": "true"}, + ) + + +def _rest_data_file( + *, + file_path: str = "s3://bucket/table/data/file.parquet", + file_format: str = "parquet", + file_size_in_bytes: int = 1024, + record_count: int = 100, +) -> dict[str, Any]: + return { + "spec-id": 0, + "content": "data", + "file-path": file_path, + "file-format": file_format, + "file-size-in-bytes": file_size_in_bytes, + "record-count": record_count, + } + + +def _rest_position_delete_file( + *, + file_path: str = "s3://bucket/table/delete.parquet", + file_format: str = "parquet", + file_size_in_bytes: int = 256, + record_count: int = 5, + content_offset: int = 100, + content_size_in_bytes: int = 200, +) -> dict[str, Any]: + return { + "spec-id": 0, + "content": "position-deletes", + "file-path": file_path, + "file-format": file_format, + "file-size-in-bytes": file_size_in_bytes, + "record-count": record_count, + "content-offset": content_offset, + "content-size-in-bytes": content_size_in_bytes, + } + + +def _rest_equality_delete_file( + *, + file_path: str = "s3://bucket/table/eq-delete.parquet", + equality_ids: list[int], + file_format: str = "parquet", + file_size_in_bytes: int = 256, + record_count: int = 5, +) -> dict[str, Any]: + return { + "spec-id": 0, + "content": "equality-deletes", + "file-path": file_path, + "file-format": file_format, + "file-size-in-bytes": file_size_in_bytes, 
+ "record-count": record_count, + "equality-ids": equality_ids, + } + def test_count_map_valid() -> None: cm = CountMap(keys=[1, 2, 3], values=[100, 200, 300]) @@ -62,80 +144,45 @@ def test_value_map_mixed_types() -> None: def test_data_file_parsing() -> None: - data = { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/file.parquet", - "file-format": "parquet", - "file-size-in-bytes": 1024, - "record-count": 100, - } - df = RESTDataFile.model_validate(data) + data_file = _rest_data_file(file_path="s3://bucket/table/file.parquet") + df = RESTDataFile.model_validate(data_file) assert df.content == "data" assert df.file_path == "s3://bucket/table/file.parquet" assert df.file_format == FileFormat.PARQUET - assert df.file_size_in_bytes == 1024 def test_data_file_with_stats() -> None: - data = { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/file.parquet", - "file-format": "parquet", - "file-size-in-bytes": 1024, - "record-count": 100, + data_file = _rest_data_file() + + data_file_with_stats = { + **data_file, "column-sizes": {"keys": [1, 2], "values": [500, 524]}, "value-counts": {"keys": [1, 2], "values": [100, 100]}, } - df = RESTDataFile.model_validate(data) + df = RESTDataFile.model_validate(data_file_with_stats) assert df.column_sizes is not None assert df.column_sizes.to_dict() == {1: 500, 2: 524} def test_position_delete_file() -> None: - data = { - "spec-id": 0, - "content": "position-deletes", - "file-path": "s3://bucket/table/delete.parquet", - "file-format": "parquet", - "file-size-in-bytes": 512, - "record-count": 10, - "content-offset": 100, - "content-size-in-bytes": 200, - } - pdf = RESTPositionDeleteFile.model_validate(data) + delete_file = _rest_position_delete_file(file_path="s3://bucket/table/delete.puffin", file_format="puffin") + pdf = RESTPositionDeleteFile.model_validate(delete_file) assert pdf.content == "position-deletes" assert pdf.content_offset == 100 assert pdf.content_size_in_bytes == 200 def test_equality_delete_file() -> None: - data = { - "spec-id": 0, - "content": "equality-deletes", - "file-path": "s3://bucket/table/eq-delete.parquet", - "file-format": "parquet", - "file-size-in-bytes": 256, - "record-count": 5, - "equality-ids": [1, 2], - } - edf = RESTEqualityDeleteFile.model_validate(data) - assert edf.content == "equality-deletes" - assert edf.equality_ids == [1, 2] + delete_file = _rest_equality_delete_file(equality_ids=[1, 2]) + equality_delete = RESTEqualityDeleteFile.model_validate(delete_file) + assert equality_delete.content == "equality-deletes" + assert equality_delete.equality_ids == [1, 2] def test_file_format_case_insensitive() -> None: for fmt in ["parquet", "PARQUET", "Parquet"]: - data = { - "spec-id": 0, - "content": "data", - "file-path": "/path", - "file-format": fmt, - "file-size-in-bytes": 100, - "record-count": 10, - } - df = RESTDataFile.model_validate(data) + data_file = _rest_data_file(file_format=fmt) + df = RESTDataFile.model_validate(data_file) assert df.file_format == FileFormat.PARQUET @@ -148,56 +195,27 @@ def test_file_format_case_insensitive() -> None: ], ) def test_file_formats(format_str: str, expected: FileFormat) -> None: - data = { - "spec-id": 0, - "content": "data", - "file-path": f"s3://bucket/table/path/file.{format_str}", - "file-format": format_str, - "file-size-in-bytes": 1024, - "record-count": 100, - } - df = RESTDataFile.model_validate(data) + data_file = _rest_data_file(file_format=format_str) + df = RESTDataFile.model_validate(data_file) assert df.file_format == 
expected def test_delete_file_discriminator_position() -> None: - data = { - "spec-id": 0, - "content": "position-deletes", - "file-path": "s3://bucket/table/delete.parquet", - "file-format": "parquet", - "file-size-in-bytes": 256, - "record-count": 5, - } - result = TypeAdapter(RESTDeleteFile).validate_python(data) + delete_file = _rest_position_delete_file() + result = TypeAdapter(RESTDeleteFile).validate_python(delete_file) assert isinstance(result, RESTPositionDeleteFile) def test_delete_file_discriminator_equality() -> None: - data = { - "spec-id": 0, - "content": "equality-deletes", - "file-path": "s3://bucket/table/delete.parquet", - "file-format": "parquet", - "file-size-in-bytes": 256, - "record-count": 5, - "equality-ids": [1], - } - result = TypeAdapter(RESTDeleteFile).validate_python(data) + delete_file = _rest_equality_delete_file(equality_ids=[1, 2]) + result = TypeAdapter(RESTDeleteFile).validate_python(delete_file) assert isinstance(result, RESTEqualityDeleteFile) def test_basic_scan_task() -> None: - data = { - "data-file": { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/file.parquet", - "file-format": "parquet", - "file-size-in-bytes": 1024, - "record-count": 100, - } - } + data_file = _rest_data_file(file_path="s3://bucket/table/file.parquet") + + data = {"data-file": data_file} task = RESTFileScanTask.model_validate(data) assert task.data_file.file_path == "s3://bucket/table/file.parquet" assert task.delete_file_references is None @@ -205,15 +223,9 @@ def test_basic_scan_task() -> None: def test_scan_task_with_delete_references() -> None: + data_file = _rest_data_file() data = { - "data-file": { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/file.parquet", - "file-format": "parquet", - "file-size-in-bytes": 1024, - "record-count": 100, - }, + "data-file": data_file, "delete-file-references": [0, 1, 2], } task = RESTFileScanTask.model_validate(data) @@ -221,15 +233,9 @@ def test_scan_task_with_delete_references() -> None: def test_scan_task_with_residual_filter_true() -> None: + data_file = _rest_data_file() data = { - "data-file": { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/file.parquet", - "file-format": "parquet", - "file-size-in-bytes": 1024, - "record-count": 100, - }, + "data-file": data_file, "residual-filter": True, } task = RESTFileScanTask.model_validate(data) @@ -249,27 +255,13 @@ def test_empty_scan_tasks() -> None: def test_scan_tasks_with_files() -> None: + data_file = _rest_data_file(file_path="s3://bucket/table/data.parquet") + delete_file = _rest_position_delete_file() data = { - "delete-files": [ - { - "spec-id": 0, - "content": "position-deletes", - "file-path": "s3://bucket/table/delete.parquet", - "file-format": "parquet", - "file-size-in-bytes": 256, - "record-count": 5, - } - ], + "delete-files": [delete_file], "file-scan-tasks": [ { - "data-file": { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/data.parquet", - "file-format": "parquet", - "file-size-in-bytes": 1024, - "record-count": 100, - }, + "data-file": data_file, "delete-file-references": [0], } ], @@ -282,18 +274,12 @@ def test_scan_tasks_with_files() -> None: def test_invalid_delete_file_reference() -> None: + data_file = _rest_data_file(file_path="s3://bucket/table/data.parquet") data = { "delete-files": [], "file-scan-tasks": [ { - "data-file": { - "spec-id": 0, - "content": "data", - "file-path": "s3://bucket/table/data.parquet", - "file-format": "parquet", - "file-size-in-bytes": 
1024, - "record-count": 100, - }, + "data-file": data_file, "delete-file-references": [0], } ], @@ -305,17 +291,9 @@ def test_invalid_delete_file_reference() -> None: def test_delete_files_require_file_scan_tasks() -> None: + delete_file = _rest_position_delete_file() data = { - "delete-files": [ - { - "spec-id": 0, - "content": "position-deletes", - "file-path": "s3://bucket/table/delete.parquet", - "file-format": "parquet", - "file-size-in-bytes": 256, - "record-count": 5, - } - ], + "delete-files": [delete_file], "file-scan-tasks": [], "plan-tasks": [], } @@ -437,14 +415,156 @@ def test_cancelled_response() -> None: assert isinstance(result, PlanCancelled) -def test_storage_credential_parsing() -> None: - data = { - "prefix": "s3://bucket/path/", - "config": { - "s3.access-key-id": "key", - "s3.secret-access-key": "secret", +def test_plan_scan_completed_single_batch(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + file_one = _rest_data_file(file_path="s3://bucket/tbl/data/file1.parquet") + + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [{"data-file": file_one}], + "plan-tasks": [], }, - } - cred = StorageCredential.model_validate(data) - assert cred.prefix == "s3://bucket/path/" - assert cred.config["s3.access-key-id"] == "key" + status_code=200, + ) + + request = PlanTableScanRequest() + tasks = list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) + + assert len(tasks) == 1 + assert tasks[0].file.file_path == "s3://bucket/tbl/data/file1.parquet" + + +def test_plan_scan_with_pagination(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + file_one = _rest_data_file(file_path="s3://bucket/tbl/data/file1.parquet") + file_two = _rest_data_file(file_path="s3://bucket/tbl/data/file2.parquet") + + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [{"data-file": file_one}], + "plan-tasks": ["token-batch-2"], + }, + status_code=200, + ) + + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/tasks", + json={ + "delete-files": [], + "file-scan-tasks": [{"data-file": file_two}], + "plan-tasks": [], + }, + status_code=200, + ) + + request = PlanTableScanRequest() + + tasks = list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) + + assert len(tasks) == 2 + assert tasks[0].file.file_path == "s3://bucket/tbl/data/file1.parquet" + assert tasks[1].file.file_path == "s3://bucket/tbl/data/file2.parquet" + + +def test_plan_scan_with_delete_files(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + file_one = _rest_data_file(file_path="s3://bucket/tbl/data/file1.parquet") + delete_file = _rest_position_delete_file() + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [delete_file], + "file-scan-tasks": [ + { + "data-file": file_one, + "delete-file-references": [0], + } + ], + "plan-tasks": [], + }, + status_code=200, + ) + + request = PlanTableScanRequest() + tasks = list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) + + assert len(tasks) == 1 + assert tasks[0].file.file_path == "s3://bucket/tbl/data/file1.parquet" + assert len(tasks[0].delete_files) == 1 + + +def test_plan_scan_async_not_supported(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + requests_mock.post( + 
f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "submitted", + "plan-id": "plan-456", + }, + status_code=200, + ) + + request = PlanTableScanRequest() + with pytest.raises(NotImplementedError, match="Async scan planning not yet supported"): + list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) + + +def test_plan_scan_empty_result(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [], + "file-scan-tasks": [], + "plan-tasks": [], + }, + status_code=200, + ) + + request = PlanTableScanRequest() + tasks = list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) + assert len(tasks) == 0 + + +def test_plan_scan_cancelled(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={"status": "cancelled"}, + status_code=200, + ) + + request = PlanTableScanRequest() + with pytest.raises(RuntimeError, match="Received status: cancelled"): + list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) + + +def test_plan_scan_equality_deletes_not_supported(rest_scan_catalog: RestCatalog, requests_mock: Mocker) -> None: + file_one = _rest_data_file(file_path="s3://bucket/tbl/data/file1.parquet") + equality_delete = _rest_equality_delete_file(equality_ids=[1, 2]) + requests_mock.post( + f"{TEST_URI}v1/namespaces/db/tables/tbl/plan", + json={ + "status": "completed", + "plan-id": "plan-123", + "delete-files": [equality_delete], + "file-scan-tasks": [ + { + "data-file": file_one, + "delete-file-references": [0], + } + ], + "plan-tasks": [], + }, + status_code=200, + ) + + request = PlanTableScanRequest() + with pytest.raises(NotImplementedError, match="PyIceberg does not yet support equality deletes"): + list(rest_scan_catalog.plan_scan(("db", "tbl"), request)) diff --git a/tests/integration/test_rest_scan_planning_integration.py b/tests/integration/test_rest_scan_planning_integration.py new file mode 100644 index 0000000000..456dbe41a6 --- /dev/null +++ b/tests/integration/test_rest_scan_planning_integration.py @@ -0,0 +1,346 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from typing import Any +from uuid import uuid4 + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession + +from pyiceberg.catalog import Catalog, load_catalog +from pyiceberg.catalog.rest import RestCatalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.expressions import ( + And, + BooleanExpression, + EqualTo, + GreaterThan, + GreaterThanOrEqual, + In, + IsNull, + LessThan, + LessThanOrEqual, + Not, + NotEqualTo, + NotIn, + NotNull, + Or, + StartsWith, +) +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table import ALWAYS_TRUE, Table +from pyiceberg.transforms import ( + IdentityTransform, +) +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DecimalType, + DoubleType, + FixedType, + LongType, + NestedField, + StringType, + TimestampType, + TimestamptzType, + TimeType, + UUIDType, +) + + +@pytest.fixture(scope="session") +def scan_catalog() -> Catalog: + catalog = load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + "rest-scan-planning-enabled": "true", + }, + ) + catalog.create_namespace_if_not_exists("default") + return catalog + + +def recreate_table(catalog: Catalog, identifier: str, **kwargs: Any) -> Table: + """Drop table if exists and create a new one.""" + try: + catalog.drop_table(identifier) + except NoSuchTableError: + pass + return catalog.create_table(identifier, **kwargs) + + +def _assert_remote_scan_matches_local_scan( + rest_table: Table, + session_catalog: Catalog, + identifier: str, + row_filter: BooleanExpression = ALWAYS_TRUE, +) -> None: + rest_tasks = list(rest_table.scan(row_filter=row_filter).plan_files()) + rest_paths = {task.file.file_path for task in rest_tasks} + + local_table = session_catalog.load_table(identifier) + local_tasks = list(local_table.scan(row_filter=row_filter).plan_files()) + local_paths = {task.file.file_path for task in local_tasks} + + assert rest_paths == local_paths + + +@pytest.mark.integration +def test_rest_scan_matches_local(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan" + + table = recreate_table( + scan_catalog, + identifier, + schema=Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", StringType()), + NestedField(3, "num", LongType()), + ), + ) + table.append(pa.Table.from_pydict({"id": [1, 2, 3], "data": ["a", "b", "c"], "num": [10, 20, 30]})) + table.append(pa.Table.from_pydict({"id": [4, 5, 6], "data": ["d", "e", "f"], "num": [40, 50, 60]})) + + try: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_with_filter(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan_filter" + + table = recreate_table( + scan_catalog, + identifier, + schema=Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", LongType()), + ), + ) + table.append(pa.Table.from_pydict({"id": [1, 2, 3], "data": [10, 20, 30]})) + + try: + _assert_remote_scan_matches_local_scan( + table, + session_catalog, + identifier, + row_filter=And(GreaterThan("data", 5), LessThan("data", 25)), + ) + + _assert_remote_scan_matches_local_scan( + table, + 
session_catalog, + identifier, + row_filter=EqualTo("id", 1), + ) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_with_deletes(spark: SparkSession, scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan_deletes" + + spark.sql(f"DROP TABLE IF EXISTS {identifier}") + spark.sql(f""" + CREATE TABLE {identifier} (id bigint, data bigint) + USING iceberg + TBLPROPERTIES( + 'format-version' = 2, + 'write.delete.mode'='merge-on-read' + ) + """) + spark.sql(f"INSERT INTO {identifier} VALUES (1, 10), (2, 20), (3, 30)") + spark.sql(f"DELETE FROM {identifier} WHERE id = 2") + + try: + rest_table = scan_catalog.load_table(identifier) + rest_tasks = list(rest_table.scan().plan_files()) + rest_paths = {task.file.file_path for task in rest_tasks} + rest_delete_paths = {delete.file_path for task in rest_tasks for delete in task.delete_files} + + local_table = session_catalog.load_table(identifier) + local_tasks = list(local_table.scan().plan_files()) + local_paths = {task.file.file_path for task in local_tasks} + local_delete_paths = {delete.file_path for task in local_tasks for delete in task.delete_files} + + assert rest_paths == local_paths + assert rest_delete_paths == local_delete_paths + finally: + spark.sql(f"DROP TABLE IF EXISTS {identifier}") + + +@pytest.mark.integration +def test_rest_scan_with_partitioning(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_rest_scan_partitioned" + + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "category", StringType()), + NestedField(3, "data", LongType()), + ) + partition_spec = PartitionSpec(PartitionField(2, 1000, IdentityTransform(), "category")) + + table = recreate_table(scan_catalog, identifier, schema=schema, partition_spec=partition_spec) + + table.append(pa.Table.from_pydict({"id": [1, 2], "category": ["a", "a"], "data": [10, 20]})) + table.append(pa.Table.from_pydict({"id": [3, 4], "category": ["b", "b"], "data": [30, 40]})) + + try: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier) + + # test filter against partition + _assert_remote_scan_matches_local_scan( + table, + session_catalog, + identifier, + row_filter=EqualTo("category", "a"), + ) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_primitive_types(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_primitives" + + schema = Schema( + NestedField(1, "bool_col", BooleanType()), + NestedField(2, "long_col", LongType()), + NestedField(3, "double_col", DoubleType()), + NestedField(4, "decimal_col", DecimalType(10, 2)), + NestedField(5, "string_col", StringType()), + NestedField(6, "date_col", DateType()), + NestedField(7, "time_col", TimeType()), + NestedField(8, "timestamp_col", TimestampType()), + NestedField(9, "timestamptz_col", TimestamptzType()), + NestedField(10, "uuid_col", UUIDType()), + NestedField(11, "fixed_col", FixedType(16)), + NestedField(12, "binary_col", BinaryType()), + ) + + table = recreate_table(scan_catalog, identifier, schema=schema) + + now = datetime.now() + now_tz = datetime.now(tz=timezone.utc) + today = date.today() + uuid1, uuid2, uuid3 = uuid4(), uuid4(), uuid4() + + arrow_table = pa.Table.from_pydict( + { + "bool_col": [True, False, True], + "long_col": [100, 200, 300], + "double_col": [1.11, 2.22, 3.33], + "decimal_col": [Decimal("1.23"), Decimal("4.56"), Decimal("7.89")], + "string_col": 
["a", "b", "c"], + "date_col": [today, today - timedelta(days=1), today - timedelta(days=2)], + "time_col": [time(8, 30, 0), time(12, 0, 0), time(18, 45, 30)], + "timestamp_col": [now, now - timedelta(hours=1), now - timedelta(hours=2)], + "timestamptz_col": [now_tz, now_tz - timedelta(hours=1), now_tz - timedelta(hours=2)], + "uuid_col": [uuid1.bytes, uuid2.bytes, uuid3.bytes], + "fixed_col": [b"0123456789abcdef", b"abcdef0123456789", b"fedcba9876543210"], + "binary_col": [b"hello", b"world", b"test"], + }, + schema=schema.as_arrow(), + ) + table.append(arrow_table) + + try: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_with_filters(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_complex_filters" + + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "name", StringType()), + NestedField(3, "value", LongType()), + NestedField(4, "optional", StringType(), required=False), + ) + + table = recreate_table(scan_catalog, identifier, schema=schema) + + table.append( + pa.Table.from_pydict( + { + "id": list(range(1, 21)), + "name": [f"item_{i}" for i in range(1, 21)], + "value": [i * 100 for i in range(1, 21)], + "optional": [None if i % 3 == 0 else f"opt_{i}" for i in range(1, 21)], + } + ) + ) + + try: + filters = [ + EqualTo("id", 10), + NotEqualTo("id", 10), + GreaterThan("value", 1000), + GreaterThanOrEqual("value", 1000), + LessThan("value", 500), + LessThanOrEqual("value", 500), + In("id", [1, 5, 10, 15]), + NotIn("id", [1, 5, 10, 15]), + IsNull("optional"), + NotNull("optional"), + StartsWith("name", "item_1"), + And(GreaterThan("id", 5), LessThan("id", 15)), + Or(EqualTo("id", 1), EqualTo("id", 20)), + Not(EqualTo("id", 10)), + ] + + for filter_expr in filters: + _assert_remote_scan_matches_local_scan(table, session_catalog, identifier, row_filter=filter_expr) + finally: + scan_catalog.drop_table(identifier) + + +@pytest.mark.integration +def test_rest_scan_empty_table(scan_catalog: RestCatalog, session_catalog: Catalog) -> None: + identifier = "default.test_empty_table" + + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", StringType()), + ) + + table = recreate_table(scan_catalog, identifier, schema=schema) + + try: + rest_tasks = list(table.scan().plan_files()) + local_table = session_catalog.load_table(identifier) + local_tasks = list(local_table.scan().plan_files()) + + assert len(rest_tasks) == 0 + assert len(local_tasks) == 0 + finally: + scan_catalog.drop_table(identifier)