diff --git a/mkdocs/docs/geospatial.md b/mkdocs/docs/geospatial.md index f7b8433b49..30aac5d383 100644 --- a/mkdocs/docs/geospatial.md +++ b/mkdocs/docs/geospatial.md @@ -84,11 +84,11 @@ point_wkb = bytes.fromhex("0101000000000000000000000000000000000000") 1. **WKB/WKT Conversion**: Converting between WKB bytes and WKT strings requires external libraries (like Shapely). PyIceberg does not include this conversion to avoid heavy dependencies. -2. **Spatial Predicates**: Spatial filtering (e.g., ST_Contains, ST_Intersects) is not yet supported for query pushdown. +2. **Spatial Predicates Execution**: Spatial predicate APIs (`st-contains`, `st-intersects`, `st-within`, `st-overlaps`) are available in expression trees and binding. Row-level execution and metrics/pushdown evaluation are not implemented yet. -3. **Bounds Metrics**: Geometry/geography columns do not currently contribute to data file bounds metrics. +3. **Without geoarrow-pyarrow**: When the `geoarrow-pyarrow` package is not installed, geometry and geography columns are stored as binary without GeoArrow extension type metadata. The Iceberg schema preserves type information, but other tools reading the Parquet files directly may not recognize them as spatial types. Install with `pip install pyiceberg[geoarrow]` for full GeoArrow support. -4. **Without geoarrow-pyarrow**: When the `geoarrow-pyarrow` package is not installed, geometry and geography columns are stored as binary without GeoArrow extension type metadata. The Iceberg schema preserves type information, but other tools reading the Parquet files directly may not recognize them as spatial types. Install with `pip install pyiceberg[geoarrow]` for full GeoArrow support. +4. **GeoArrow planar ambiguity**: In GeoArrow metadata, `geometry` and `geography(..., 'planar')` can be encoded identically (no explicit edge metadata). PyIceberg resolves this ambiguity at the Arrow/Parquet schema-compatibility boundary by treating them as compatible when CRS matches, while keeping core schema compatibility strict elsewhere. ## Format Version diff --git a/mkdocs/docs/plans/Geospatial_PR_How_To_Review.md b/mkdocs/docs/plans/Geospatial_PR_How_To_Review.md new file mode 100644 index 0000000000..cdde383564 --- /dev/null +++ b/mkdocs/docs/plans/Geospatial_PR_How_To_Review.md @@ -0,0 +1,75 @@ +# How To Review: Geospatial Compatibility, Metrics, and Expressions PR + +## Goal + +This PR is large because it spans expression APIs, Arrow/Parquet conversion, metrics generation, and documentation. +Recommended strategy: review in focused passes by concern, not file order. + +## Recommended Review Order + +1. **Core geospatial utility correctness** + - `pyiceberg/utils/geospatial.py` + - `tests/utils/test_geospatial.py` + - Focus on envelope extraction, antimeridian behavior, and bound encoding formats. + +1. **Metrics integration and write/import paths** + - `pyiceberg/io/pyarrow.py` + - `tests/io/test_pyarrow_stats.py` + - Focus on: + - geospatial bounds derived from row WKB values + - skipping Parquet binary min/max for geospatial columns + - partition inference not using geospatial envelope bounds + +1. **GeoArrow compatibility and ambiguity boundary** + - `pyiceberg/schema.py` + - `pyiceberg/io/pyarrow.py` + - `tests/io/test_pyarrow.py` + - Confirm: + - planar equivalence enabled only at Arrow/Parquet boundary + - spherical mismatch still fails + - fallback warning when GeoArrow dependency is absent + +1. 
**Spatial expression surface area** + - `pyiceberg/expressions/__init__.py` + - `pyiceberg/expressions/visitors.py` + - `tests/expressions/test_spatial_predicates.py` + - Confirm: + - bind-time type checks (geometry/geography only) + - visitor plumbing is complete + - conservative evaluator behavior is explicit and documented + +1. **User-facing docs** + - `mkdocs/docs/geospatial.md` + - Check limitations and behavior notes match implementation. + +## High-Risk Areas To Inspect Closely + +1. **Boundary scope leakage** + - Ensure planar-equivalence relaxation is not enabled globally. + +2. **Envelope semantics** + - Geography antimeridian cases (`xmin > xmax`) are expected and intentional. + +3. **Metrics correctness** + - Geospatial bounds are serialized envelopes, not raw value min/max. + +4. **Conservative evaluator behavior** + - Spatial predicates should not accidentally become strict in metrics/manifest evaluators. + +## Quick Validation Commands + +```bash +uv run --extra hive --extra bigquery python -m pytest tests/utils/test_geospatial.py -q +uv run --extra hive --extra bigquery python -m pytest tests/io/test_pyarrow_stats.py -k "geospatial or planar_geography_schema or partition_inference_skips_geospatial_bounds" -q +uv run --extra hive --extra bigquery python -m pytest tests/io/test_pyarrow.py -k "geoarrow or planar_geography_geometry_equivalence or spherical_geography_geometry_equivalence or logs_warning_once" -q +uv run --extra hive --extra bigquery python -m pytest tests/expressions/test_spatial_predicates.py tests/expressions/test_visitors.py -k "spatial or translate_column_names" -q +``` + +## Review Outcome Checklist + +1. Geometry/geography bounds are present and correctly encoded for write/import paths. +2. `geometry` vs `geography(planar)` is only equivalent at Arrow/Parquet compatibility boundary with CRS equality. +3. `geography(spherical)` remains incompatible with `geometry`. +4. Spatial predicates are correctly modeled/bound; execution and pushdown remain intentionally unimplemented. +5. Missing GeoArrow dependency degrades gracefully with explicit warning. +6. Docs match implemented behavior and limitations. 
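## Appendix: Expression Surface At A Glance

A minimal, self-contained sketch of the API under review, mirroring `tests/expressions/test_spatial_predicates.py` (the schema and the `geom` field name are illustrative, not part of this PR):

```python
import struct

from pyiceberg.expressions import BooleanExpression, STContains
from pyiceberg.expressions.visitors import bind
from pyiceberg.schema import Schema
from pyiceberg.types import GeometryType, NestedField

# POINT(1 2) as little-endian WKB: byte-order marker, geometry type, x, y.
point_wkb = struct.pack("<BIdd", 1, 1, 1.0, 2.0)

schema = Schema(NestedField(1, "geom", GeometryType(), required=False), schema_id=1)

# Binding type-checks the term: only geometry/geography fields are accepted.
bound = bind(schema, STContains("geom", point_wkb), case_sensitive=True)

# Spatial predicates round-trip through the pydantic type discriminator.
parsed = BooleanExpression.model_validate({"type": "st-contains", "term": "geom", "value": point_wkb})
assert parsed == STContains("geom", point_wkb)
```

Row-level evaluation of the bound predicate deliberately raises `NotImplementedError`; only the modeling, binding, and conservative evaluator plumbing are in scope here.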
diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 3910a146c7..019200e0de 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -17,10 +17,11 @@ from __future__ import annotations +import builtins from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from functools import cached_property -from typing import Any, TypeAlias +from typing import Any, TypeAlias, cast from typing import Literal as TypingLiteral from pydantic import ConfigDict, Field, SerializeAsAny, model_validator @@ -29,7 +30,7 @@ from pyiceberg.expressions.literals import AboveMax, BelowMin, Literal, literal from pyiceberg.schema import Accessor, Schema from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, LiteralValue, StructProtocol -from pyiceberg.types import DoubleType, FloatType, NestedField +from pyiceberg.types import DoubleType, FloatType, GeographyType, GeometryType, NestedField from pyiceberg.utils.singleton import Singleton @@ -48,6 +49,16 @@ def _to_literal(value: L | Literal[L]) -> Literal[L]: return literal(value) +def _to_bytes(value: bytes | bytearray | memoryview) -> bytes: + if isinstance(value, bytes): + return value + if isinstance(value, bytearray): + return bytes(value) + if isinstance(value, memoryview): + return value.tobytes() + raise TypeError(f"Expected bytes-like value, got {type(value)}") + + class BooleanExpression(IcebergBaseModel, ABC): """An expression that evaluates to a boolean.""" @@ -109,6 +120,14 @@ def handle_primitive_type(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> return StartsWith(**v) elif field_type == "not-starts-with": return NotStartsWith(**v) + elif field_type == "st-contains": + return STContains(**v) + elif field_type == "st-intersects": + return STIntersects(**v) + elif field_type == "st-within": + return STWithin(**v) + elif field_type == "st-overlaps": + return STOverlaps(**v) # Set elif field_type == "in": @@ -1106,3 +1125,169 @@ def __invert__(self) -> StartsWith: @property def as_bound(self) -> type[BoundNotStartsWith]: # type: ignore return BoundNotStartsWith + + +class SpatialPredicate(UnboundPredicate, ABC): + type: TypingLiteral["st-contains", "st-intersects", "st-within", "st-overlaps"] = Field(alias="type") + term: UnboundTerm + value: bytes = Field() + model_config = ConfigDict(populate_by_name=True, frozen=True, arbitrary_types_allowed=True) + + def __init__( + self, + term: str | UnboundTerm, + geometry: bytes | bytearray | memoryview | None = None, + **kwargs: Any, + ) -> None: + if geometry is None and "value" in kwargs: + geometry = kwargs["value"] + if geometry is None: + raise TypeError("Spatial predicates require WKB bytes") + + super().__init__(term=_to_unbound_term(term), value=_to_bytes(geometry)) + + @property + def geometry(self) -> bytes: + return self.value + + def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundSpatialPredicate: + bound_term = self.term.bind(schema, case_sensitive) + if not isinstance(bound_term.ref().field.field_type, (GeometryType, GeographyType)): + raise TypeError(f"Spatial predicates can only be bound against geometry/geography fields: {bound_term.ref().field}") + bound_cls = cast(Any, self.as_bound) + return bound_cls(bound_term, self.geometry) + + def __eq__(self, other: Any) -> bool: + """Return whether two spatial predicates are equivalent.""" + if isinstance(other, self.__class__): + return self.term == other.term and self.geometry == other.geometry + return False + + 
def __str__(self) -> str: + """Return a human-readable representation.""" + return f"{str(self.__class__.__name__)}(term={repr(self.term)}, geometry={self.geometry!r})" + + def __repr__(self) -> str: + """Return the debug representation.""" + return f"{str(self.__class__.__name__)}(term={repr(self.term)}, geometry={self.geometry!r})" + + @property + @abstractmethod + def as_bound(self) -> builtins.type[BoundSpatialPredicate]: ... + + +class BoundSpatialPredicate(BoundPredicate, ABC): + value: bytes = Field() + + def __init__(self, term: BoundTerm, geometry: bytes | bytearray | memoryview): + super().__init__(term=term, value=_to_bytes(geometry)) + + @property + def geometry(self) -> bytes: + return self.value + + def __eq__(self, other: Any) -> bool: + """Return whether two bound spatial predicates are equivalent.""" + if isinstance(other, self.__class__): + return self.term == other.term and self.geometry == other.geometry + return False + + def __str__(self) -> str: + """Return a human-readable representation.""" + return f"{self.__class__.__name__}(term={str(self.term)}, geometry={self.geometry!r})" + + def __repr__(self) -> str: + """Return the debug representation.""" + return f"{str(self.__class__.__name__)}(term={repr(self.term)}, geometry={self.geometry!r})" + + @property + @abstractmethod + def as_unbound(self) -> type[SpatialPredicate]: ... + + +class BoundSTContains(BoundSpatialPredicate): + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_unbound(self) -> type[STContains]: + return STContains + + +class BoundSTIntersects(BoundSpatialPredicate): + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_unbound(self) -> type[STIntersects]: + return STIntersects + + +class BoundSTWithin(BoundSpatialPredicate): + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_unbound(self) -> type[STWithin]: + return STWithin + + +class BoundSTOverlaps(BoundSpatialPredicate): + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_unbound(self) -> type[STOverlaps]: + return STOverlaps + + +class STContains(SpatialPredicate): + type: TypingLiteral["st-contains"] = Field(default="st-contains", alias="type") + + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_bound(self) -> builtins.type[BoundSTContains]: + return BoundSTContains + + +class STIntersects(SpatialPredicate): + type: TypingLiteral["st-intersects"] = Field(default="st-intersects", alias="type") + + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_bound(self) -> builtins.type[BoundSTIntersects]: + return BoundSTIntersects + + +class STWithin(SpatialPredicate): + type: TypingLiteral["st-within"] = Field(default="st-within", alias="type") + + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + @property + def as_bound(self) -> builtins.type[BoundSTWithin]: + return BoundSTWithin + + +class STOverlaps(SpatialPredicate): + type: TypingLiteral["st-overlaps"] = Field(default="st-overlaps", alias="type") + + def __invert__(self) -> BooleanExpression: + """Return the negated expression.""" + return Not(child=self) + + 
@property + def as_bound(self) -> builtins.type[BoundSTOverlaps]: + return BoundSTOverlaps diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 0beb0f3df0..ea87f634a6 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -47,7 +47,12 @@ BoundNotStartsWith, BoundPredicate, BoundSetPredicate, + BoundSpatialPredicate, BoundStartsWith, + BoundSTContains, + BoundSTIntersects, + BoundSTOverlaps, + BoundSTWithin, BoundTerm, BoundUnaryPredicate, Not, @@ -326,6 +331,18 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> T: def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> T: """Visit bound NotStartsWith predicate.""" + def visit_st_contains(self, term: BoundTerm, geometry: bytes) -> T: + raise NotImplementedError(f"{self.__class__.__name__} does not implement st-contains") + + def visit_st_intersects(self, term: BoundTerm, geometry: bytes) -> T: + raise NotImplementedError(f"{self.__class__.__name__} does not implement st-intersects") + + def visit_st_within(self, term: BoundTerm, geometry: bytes) -> T: + raise NotImplementedError(f"{self.__class__.__name__} does not implement st-within") + + def visit_st_overlaps(self, term: BoundTerm, geometry: bytes) -> T: + raise NotImplementedError(f"{self.__class__.__name__} does not implement st-overlaps") + def visit_unbound_predicate(self, predicate: UnboundPredicate) -> T: """Visit an unbound predicate. @@ -421,6 +438,26 @@ def _(expr: BoundNotStartsWith, visitor: BoundBooleanExpressionVisitor[T]) -> T: return visitor.visit_not_starts_with(term=expr.term, literal=expr.literal) +@visit_bound_predicate.register(BoundSTContains) +def _(expr: BoundSTContains, visitor: BoundBooleanExpressionVisitor[T]) -> T: + return visitor.visit_st_contains(term=expr.term, geometry=expr.geometry) + + +@visit_bound_predicate.register(BoundSTIntersects) +def _(expr: BoundSTIntersects, visitor: BoundBooleanExpressionVisitor[T]) -> T: + return visitor.visit_st_intersects(term=expr.term, geometry=expr.geometry) + + +@visit_bound_predicate.register(BoundSTWithin) +def _(expr: BoundSTWithin, visitor: BoundBooleanExpressionVisitor[T]) -> T: + return visitor.visit_st_within(term=expr.term, geometry=expr.geometry) + + +@visit_bound_predicate.register(BoundSTOverlaps) +def _(expr: BoundSTOverlaps, visitor: BoundBooleanExpressionVisitor[T]) -> T: + return visitor.visit_st_overlaps(term=expr.term, geometry=expr.geometry) + + def rewrite_not(expr: BooleanExpression) -> BooleanExpression: return visit(expr, _RewriteNotVisitor()) @@ -514,6 +551,18 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool: def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool: return not self.visit_starts_with(term, literal) + def visit_st_contains(self, term: BoundTerm, geometry: bytes) -> bool: + raise NotImplementedError("st-contains row-level evaluation is not implemented") + + def visit_st_intersects(self, term: BoundTerm, geometry: bytes) -> bool: + raise NotImplementedError("st-intersects row-level evaluation is not implemented") + + def visit_st_within(self, term: BoundTerm, geometry: bytes) -> bool: + raise NotImplementedError("st-within row-level evaluation is not implemented") + + def visit_st_overlaps(self, term: BoundTerm, geometry: bytes) -> bool: + raise NotImplementedError("st-overlaps row-level evaluation is not implemented") + def visit_true(self) -> bool: return True @@ -762,6 +811,18 @@ def visit_not_starts_with(self, term: 
BoundTerm, literal: LiteralValue) -> bool: return ROWS_MIGHT_MATCH + def visit_st_contains(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + + def visit_st_intersects(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + + def visit_st_within(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + + def visit_st_overlaps(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + def visit_true(self) -> bool: return ROWS_MIGHT_MATCH @@ -905,6 +966,8 @@ def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: pred = predicate.as_unbound(field.name, predicate.literal) elif isinstance(predicate, BoundSetPredicate): pred = predicate.as_unbound(field.name, predicate.literals) + elif isinstance(predicate, BoundSpatialPredicate): + raise NotImplementedError("Spatial predicate translation is not supported when source columns are missing") else: raise ValueError(f"Unsupported predicate: {predicate}") @@ -926,6 +989,8 @@ def visit_bound_predicate(self, predicate: BoundPredicate) -> BooleanExpression: return predicate.as_unbound(file_column_name, predicate.literal) elif isinstance(predicate, BoundSetPredicate): return predicate.as_unbound(file_column_name, predicate.literals) + elif isinstance(predicate, BoundSpatialPredicate): + return predicate.as_unbound(file_column_name, predicate.geometry) else: raise ValueError(f"Unsupported predicate: {predicate}") @@ -1065,6 +1130,18 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> list[tupl def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> list[tuple[str, str, Any]]: return [] + def visit_st_contains(self, term: BoundTerm, geometry: bytes) -> list[tuple[str, str, Any]]: + return [] + + def visit_st_intersects(self, term: BoundTerm, geometry: bytes) -> list[tuple[str, str, Any]]: + return [] + + def visit_st_within(self, term: BoundTerm, geometry: bytes) -> list[tuple[str, str, Any]]: + return [] + + def visit_st_overlaps(self, term: BoundTerm, geometry: bytes) -> list[tuple[str, str, Any]]: + return [] + def visit_true(self) -> list[tuple[str, str, Any]]: return [] # Not supported @@ -1153,6 +1230,18 @@ def _is_nan(self, val: Any) -> bool: # In the case of None or other non-numeric types return False + def visit_st_contains(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + + def visit_st_intersects(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + + def visit_st_within(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + + def visit_st_overlaps(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_MATCH + class _InclusiveMetricsEvaluator(_MetricsEvaluator): struct: StructType @@ -1739,6 +1828,18 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool: def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool: return ROWS_MIGHT_NOT_MATCH + def visit_st_contains(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_NOT_MATCH + + def visit_st_intersects(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_NOT_MATCH + + def visit_st_within(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_NOT_MATCH + + def visit_st_overlaps(self, term: BoundTerm, geometry: bytes) -> bool: + return ROWS_MIGHT_NOT_MATCH + def _get_field(self, field_id: int) -> NestedField: field = self.struct.field(field_id=field_id) if field is None: 
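Reviewer aid (not part of the diff): the conservative contract implemented by the visitors above can be exercised with a short sketch. The schema and the `geom` field name are illustrative, and the calls mirror `tests/expressions/test_spatial_predicates.py`:

```python
import struct

from pyiceberg.expressions import STIntersects
from pyiceberg.expressions.visitors import expression_evaluator
from pyiceberg.schema import Schema
from pyiceberg.typedef import Record
from pyiceberg.types import GeometryType, NestedField

schema = Schema(NestedField(1, "geom", GeometryType(), required=False), schema_id=1)
point_wkb = struct.pack("<BIdd", 1, 1, 0.0, 0.0)  # POINT(0 0), little-endian WKB

# Metrics and manifest evaluators answer ROWS_MIGHT_MATCH / ROWS_MIGHT_NOT_MATCH
# for spatial predicates, so no file is pruned; row-level evaluation raises
# until it is implemented.
evaluator = expression_evaluator(schema, STIntersects("geom", point_wkb), case_sensitive=True)
try:
    evaluator(Record(point_wkb))
except NotImplementedError as exc:
    print(exc)  # st-intersects row-level evaluation is not implemented
```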
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 0dfc5eb55a..6b870b6139 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -30,6 +30,7 @@ import functools import importlib import itertools +import json import logging import operator import os @@ -180,6 +181,12 @@ from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import millis_to_datetime from pyiceberg.utils.decimal import unscaled_to_decimal +from pyiceberg.utils.geospatial import ( + GeometryEnvelope, + extract_envelope_from_wkb, + merge_envelopes, + serialize_geospatial_bound, +) from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int from pyiceberg.utils.singleton import Singleton from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string @@ -207,6 +214,14 @@ T = TypeVar("T") +@lru_cache(maxsize=1) +def _warn_geoarrow_unavailable() -> None: + logger.warning( + "geoarrow-pyarrow is not installed; falling back to binary for geometry/geography columns. " + "Install pyiceberg with the geoarrow extra to preserve GeoArrow metadata in Parquet." + ) + + @lru_cache def _cached_resolve_s3_region(bucket: str) -> str | None: from pyarrow.fs import resolve_s3_region @@ -812,6 +827,7 @@ def visit_geometry(self, geometry_type: GeometryType) -> pa.DataType: return ga.wkb().with_crs(geometry_type.crs) except ImportError: + _warn_geoarrow_unavailable() return pa.large_binary() def visit_geography(self, geography_type: GeographyType) -> pa.DataType: @@ -830,6 +846,7 @@ def visit_geography(self, geography_type: GeographyType) -> pa.DataType: # "planar" is the default edge type in GeoArrow, no need to set explicitly return wkb_type except ImportError: + _warn_geoarrow_unavailable() return pa.large_binary() @@ -1341,6 +1358,51 @@ def _get_field_id(field: pa.Field) -> int | None: return None +def _geoarrow_wkb_to_iceberg(primitive: pa.DataType) -> PrimitiveType | None: + if not isinstance(primitive, pa.ExtensionType) or primitive.extension_name != "geoarrow.wkb": + return None + + # Default CRS in the Iceberg spec for both geometry and geography. + crs = "OGC:CRS84" + + # Avoid conversions that may require optional CRS dependencies. + primitive_crs = getattr(primitive, "crs", None) + raw_crs = getattr(primitive_crs, "_crs", None) + if isinstance(raw_crs, str) and raw_crs: + crs = raw_crs + elif isinstance(primitive_crs, str) and primitive_crs: + crs = primitive_crs + + edges: str | None = None + try: + serialized = primitive.__arrow_ext_serialize__() + if serialized: + payload = json.loads(serialized.decode("utf-8")) + if isinstance(payload, dict): + if isinstance(payload.get("crs"), str) and payload["crs"]: + crs = payload["crs"] + if isinstance(payload.get("edges"), str): + edges = payload["edges"].lower() + except (AttributeError, UnicodeDecodeError, json.JSONDecodeError): + pass + + if edges is None: + edge_type = getattr(primitive, "edge_type", None) + edge_name = getattr(edge_type, "name", None) + if isinstance(edge_name, str) and edge_name.lower() == "spherical": + edges = edge_name.lower() + + if edges == "spherical": + return GeographyType(crs, "spherical") + if edges == "planar": + return GeographyType(crs, "planar") + + # GeoArrow WKB without explicit edge semantics maps best to geometry. + # This is ambiguous with geography(planar); compatibility for that case is handled + # explicitly at the Arrow/Parquet schema-compatibility boundary. 
+ return GeometryType(crs) + + class _HasIds(PyArrowSchemaVisitor[bool]): def schema(self, schema: pa.Schema, struct_result: bool) -> bool: return struct_result @@ -1466,6 +1528,8 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType: return TimestamptzType() elif primitive.tz is None: return TimestampType() + elif geo_type := _geoarrow_wkb_to_iceberg(primitive): + return geo_type elif pa.types.is_binary(primitive) or pa.types.is_large_binary(primitive) or pa.types.is_binary_view(primitive): return BinaryType() @@ -2248,6 +2312,58 @@ def max_as_bytes(self) -> bytes | None: return self.serialize(self.current_max) +class GeospatialStatsAggregator: + primitive_type: PrimitiveType + current_min: Any + current_max: Any + + def __init__(self, iceberg_type: PrimitiveType) -> None: + if not isinstance(iceberg_type, (GeometryType, GeographyType)): + raise ValueError(f"Expected GeometryType or GeographyType, got {iceberg_type}") + self.primitive_type = iceberg_type + self.current_min = None + self.current_max = None + self._envelope: GeometryEnvelope | None = None + + def update_from_wkb(self, val: bytes | bytearray | memoryview | None) -> None: + if val is None: + return + + envelope = extract_envelope_from_wkb(bytes(val), isinstance(self.primitive_type, GeographyType)) + if envelope is None: + return + + if self._envelope is None: + self._envelope = envelope + else: + self._envelope = merge_envelopes( + self._envelope, + envelope, + is_geography=isinstance(self.primitive_type, GeographyType), + ) + + self.current_min = self._envelope.to_min_bound() + self.current_max = self._envelope.to_max_bound() + + def update_min(self, val: Any | None) -> None: + if isinstance(val, (bytes, bytearray, memoryview)): + self.update_from_wkb(val) + + def update_max(self, val: Any | None) -> None: + if isinstance(val, (bytes, bytearray, memoryview)): + self.update_from_wkb(val) + + def min_as_bytes(self) -> bytes | None: + if self._envelope is None: + return None + return serialize_geospatial_bound(self._envelope.to_min_bound()) + + def max_as_bytes(self) -> bytes | None: + if self._envelope is None: + return None + return serialize_geospatial_bound(self._envelope.to_max_bound()) + + DEFAULT_TRUNCATION_LENGTH = 16 TRUNCATION_EXPR = r"^truncate\((\d+)\)$" @@ -2480,7 +2596,7 @@ class DataFileStatistics: value_counts: dict[int, int] null_value_counts: dict[int, int] nan_value_counts: dict[int, int] - column_aggregates: dict[int, StatsAggregator] + column_aggregates: dict[int, StatsAggregator | GeospatialStatsAggregator] split_offsets: list[int] def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any: @@ -2488,6 +2604,11 @@ def _partition_value(self, partition_field: PartitionField, schema: Schema) -> A return None source_field = schema.find_field(partition_field.source_id) + if isinstance(source_field.field_type, (GeometryType, GeographyType)): + # Geospatial lower/upper bounds encode envelope extrema, not original values, + # so they cannot be used to infer a partition value. 
+ return None + iceberg_transform = partition_field.transform if not iceberg_transform.preserves_order: @@ -2546,6 +2667,78 @@ def to_serialized_dict(self) -> dict[str, Any]: } +def _iter_wkb_values(column: pa.Array | ChunkedArray) -> Iterator[bytes]: + chunks = column.chunks if isinstance(column, ChunkedArray) else [column] + for chunk in chunks: + if isinstance(chunk, pa.ExtensionArray): + chunk = chunk.storage + + for scalar in chunk: + if not scalar.is_valid: + continue + + value = scalar.as_py() + if isinstance(value, bytes): + yield value + elif isinstance(value, bytearray): + yield bytes(value) + elif isinstance(value, memoryview): + yield value.tobytes() + elif hasattr(value, "to_wkb"): + yield bytes(value.to_wkb()) + else: + raise ValueError(f"Expected a bytes-like WKB value, got {type(value)}") + + +def geospatial_column_aggregates_from_arrow_table( + arrow_table: pa.Table, stats_columns: dict[int, StatisticsCollector] +) -> dict[int, GeospatialStatsAggregator]: + geospatial_aggregates: dict[int, GeospatialStatsAggregator] = {} + + for field_id, stats_col in stats_columns.items(): + if stats_col.mode.type in (MetricModeTypes.NONE, MetricModeTypes.COUNTS): + continue + + if not isinstance(stats_col.iceberg_type, (GeometryType, GeographyType)): + continue + + column = _get_field_from_arrow_table(arrow_table, stats_col.column_name) + aggregator = GeospatialStatsAggregator(stats_col.iceberg_type) + + try: + for value in _iter_wkb_values(column): + aggregator.update_from_wkb(value) + except ValueError as exc: + logger.warning("Skipping geospatial bounds for column %s: %s", stats_col.column_name, exc) + continue + + if aggregator.min_as_bytes() is not None: + geospatial_aggregates[field_id] = aggregator + + return geospatial_aggregates + + +def geospatial_column_aggregates_from_parquet_file( + input_file: InputFile, stats_columns: dict[int, StatisticsCollector] +) -> dict[int, GeospatialStatsAggregator]: + geospatial_stats_columns = { + field_id: stats_col + for field_id, stats_col in stats_columns.items() + if stats_col.mode.type not in (MetricModeTypes.NONE, MetricModeTypes.COUNTS) + and isinstance(stats_col.iceberg_type, (GeometryType, GeographyType)) + } + if not geospatial_stats_columns: + return {} + + with input_file.open() as input_stream: + arrow_table = pq.read_table( + input_stream, + columns=[stats_col.column_name for stats_col in geospatial_stats_columns.values()], + ) + + return geospatial_column_aggregates_from_arrow_table(arrow_table, geospatial_stats_columns) + + def data_file_statistics_from_parquet_metadata( parquet_metadata: pq.FileMetaData, stats_columns: dict[int, StatisticsCollector], @@ -2617,6 +2810,11 @@ def data_file_statistics_from_parquet_metadata( if stats_col.mode == MetricsMode(MetricModeTypes.COUNTS): continue + if isinstance(stats_col.iceberg_type, (GeometryType, GeographyType)): + # Geospatial metrics bounds are computed from row values (WKB parsing), + # not Parquet binary min/max statistics. 
+ continue + if field_id not in col_aggs: try: col_aggs[field_id] = StatsAggregator( @@ -2660,7 +2858,7 @@ def data_file_statistics_from_parquet_metadata( value_counts=value_counts, null_value_counts=null_value_counts, nan_value_counts=nan_value_counts, - column_aggregates=col_aggs, + column_aggregates=cast(dict[int, StatsAggregator | GeospatialStatsAggregator], col_aggs), split_offsets=split_offsets, ) @@ -2707,11 +2905,13 @@ def write_parquet(task: WriteTask) -> DataFile: fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs ) as writer: writer.write(arrow_table, row_group_size=row_group_size) + stats_columns = compute_statistics_plan(file_schema, table_metadata.properties) statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=writer.writer.metadata, - stats_columns=compute_statistics_plan(file_schema, table_metadata.properties), + stats_columns=stats_columns, parquet_column_mapping=parquet_path_to_id_mapping(file_schema), ) + statistics.column_aggregates.update(geospatial_column_aggregates_from_arrow_table(arrow_table, stats_columns)) data_file = DataFile.from_args( content=DataFileContent.DATA, file_path=file_path, @@ -2785,7 +2985,7 @@ def _check_pyarrow_schema_compatible( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. " "Update the schema first (hint, use union_by_name)." ) from e - _check_schema_compatible(requested_schema, provided_schema) + _check_schema_compatible(requested_schema, provided_schema, allow_planar_geospatial_equivalence=True) def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]: @@ -2803,12 +3003,14 @@ def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_pa schema = table_metadata.schema() _check_pyarrow_schema_compatible(schema, arrow_schema, format_version=table_metadata.format_version) + stats_columns = compute_statistics_plan(schema, table_metadata.properties) statistics = data_file_statistics_from_parquet_metadata( parquet_metadata=parquet_metadata, - stats_columns=compute_statistics_plan(schema, table_metadata.properties), + stats_columns=stats_columns, parquet_column_mapping=parquet_path_to_id_mapping(schema), ) + statistics.column_aggregates.update(geospatial_column_aggregates_from_parquet_file(input_file, stats_columns)) data_file = DataFile.from_args( content=DataFileContent.DATA, file_path=file_path, diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index fd60eb8f94..2b3abd54e0 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1723,7 +1723,9 @@ def _(file_type: UnknownType, read_type: IcebergType) -> IcebergType: raise ResolveError(f"Cannot promote {file_type} to {read_type}") -def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None: +def _check_schema_compatible( + requested_schema: Schema, provided_schema: Schema, allow_planar_geospatial_equivalence: bool = False +) -> None: """ Check if the `provided_schema` is compatible with `requested_schema`. @@ -1737,23 +1739,43 @@ def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) Raises: ValueError: If the schemas are not compatible. 
""" - pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema)) + pre_order_visit( + requested_schema, + _SchemaCompatibilityVisitor( + provided_schema=provided_schema, + allow_planar_geospatial_equivalence=allow_planar_geospatial_equivalence, + ), + ) class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]): provided_schema: Schema + allow_planar_geospatial_equivalence: bool - def __init__(self, provided_schema: Schema): + def __init__(self, provided_schema: Schema, allow_planar_geospatial_equivalence: bool = False): from rich.console import Console from rich.table import Table as RichTable self.provided_schema = provided_schema + self.allow_planar_geospatial_equivalence = allow_planar_geospatial_equivalence self.rich_table = RichTable(show_header=True, header_style="bold") self.rich_table.add_column("") self.rich_table.add_column("Table field") self.rich_table.add_column("Dataframe field") self.console = Console(record=True) + def _is_planar_geospatial_equivalent(self, lhs: IcebergType, rhs: IcebergType) -> bool: + if not self.allow_planar_geospatial_equivalence: + return False + + if isinstance(lhs, GeometryType) and isinstance(rhs, GeographyType): + return rhs.algorithm == "planar" and lhs.crs == rhs.crs + + if isinstance(lhs, GeographyType) and isinstance(rhs, GeometryType): + return lhs.algorithm == "planar" and lhs.crs == rhs.crs + + return False + def _is_field_compatible(self, lhs: NestedField) -> bool: # Validate nullability first. # An optional field can be missing in the provided schema @@ -1776,6 +1798,9 @@ def _is_field_compatible(self, lhs: NestedField) -> bool: if lhs.field_type == rhs.field_type: self.rich_table.add_row("✅", str(lhs), str(rhs)) return True + elif self._is_planar_geospatial_equivalent(lhs.field_type, rhs.field_type): + self.rich_table.add_row("✅", str(lhs), str(rhs)) + return True # We only check that the parent node is also of the same type. # We check the type of the child nodes when we traverse them later. elif any( diff --git a/pyiceberg/utils/geospatial.py b/pyiceberg/utils/geospatial.py new file mode 100644 index 0000000000..baf4044329 --- /dev/null +++ b/pyiceberg/utils/geospatial.py @@ -0,0 +1,437 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations
+
+import math
+import struct
+from dataclasses import dataclass
+
+_WKB_POINT = 1
+_WKB_LINESTRING = 2
+_WKB_POLYGON = 3
+_WKB_MULTIPOINT = 4
+_WKB_MULTILINESTRING = 5
+_WKB_MULTIPOLYGON = 6
+_WKB_GEOMETRYCOLLECTION = 7
+
+_EWKB_Z_FLAG = 0x80000000
+_EWKB_M_FLAG = 0x40000000
+_EWKB_SRID_FLAG = 0x20000000
+
+
+@dataclass(frozen=True)
+class GeospatialBound:
+    x: float
+    y: float
+    z: float | None = None
+    m: float | None = None
+
+    @property
+    def has_z(self) -> bool:
+        return self.z is not None
+
+    @property
+    def has_m(self) -> bool:
+        return self.m is not None
+
+
+@dataclass(frozen=True)
+class GeometryEnvelope:
+    x_min: float
+    y_min: float
+    z_min: float | None
+    m_min: float | None
+    x_max: float
+    y_max: float
+    z_max: float | None
+    m_max: float | None
+
+    def to_min_bound(self) -> GeospatialBound:
+        return GeospatialBound(x=self.x_min, y=self.y_min, z=self.z_min, m=self.m_min)
+
+    def to_max_bound(self) -> GeospatialBound:
+        return GeospatialBound(x=self.x_max, y=self.y_max, z=self.z_max, m=self.m_max)
+
+
+def serialize_geospatial_bound(bound: GeospatialBound) -> bytes:
+    if bound.z is None and bound.m is None:
+        return struct.pack("<dd", bound.x, bound.y)
+    if bound.m is None:
+        return struct.pack("<ddd", bound.x, bound.y, bound.z)
+    if bound.z is None:
+        # When m is set without z, z is encoded as NaN so that 24-byte values
+        # always decode as x:y:z.
+        return struct.pack("<dddd", bound.x, bound.y, float("nan"), bound.m)
+    return struct.pack("<dddd", bound.x, bound.y, bound.z, bound.m)
+
+
+def deserialize_geospatial_bound(raw: bytes) -> GeospatialBound:
+    if len(raw) == 16:
+        x, y = struct.unpack("<dd", raw)
+        return GeospatialBound(x=x, y=y)
+    if len(raw) == 24:
+        x, y, z = struct.unpack("<ddd", raw)
+        return GeospatialBound(x=x, y=y, z=z)
+    if len(raw) == 32:
+        x, y, z, m = struct.unpack("<dddd", raw)
+        return GeospatialBound(x=x, y=y, z=None if math.isnan(z) else z, m=m)
+    raise ValueError(f"Invalid geospatial bound length: {len(raw)}")
+
+
+def extract_envelope_from_wkb(wkb: bytes, is_geography: bool) -> GeometryEnvelope | None:
+    reader = _WKBReader(wkb)
+    accumulator = _EnvelopeAccumulator(is_geography=is_geography)
+    _parse_geometry(reader, accumulator)
+    if reader.remaining() != 0:
+        raise ValueError(f"Trailing bytes found after parsing WKB: {reader.remaining()}")
+    return accumulator.finish()
+
+
+def merge_envelopes(left: GeometryEnvelope, right: GeometryEnvelope, is_geography: bool) -> GeometryEnvelope:
+    if is_geography:
+        x_min, x_max = _merge_longitude_intervals(left.x_min, left.x_max, right.x_min, right.x_max)
+    else:
+        x_min, x_max = min(left.x_min, right.x_min), max(left.x_max, right.x_max)
+
+    return GeometryEnvelope(
+        x_min=x_min,
+        y_min=min(left.y_min, right.y_min),
+        z_min=_merge_optional_min(left.z_min, right.z_min),
+        m_min=_merge_optional_min(left.m_min, right.m_min),
+        x_max=x_max,
+        y_max=max(left.y_max, right.y_max),
+        z_max=_merge_optional_max(left.z_max, right.z_max),
+        m_max=_merge_optional_max(left.m_max, right.m_max),
+    )
+
+
+def _merge_optional_min(left: float | None, right: float | None) -> float | None:
+    if left is None:
+        return right
+    if right is None:
+        return left
+    return min(left, right)
+
+
+def _merge_optional_max(left: float | None, right: float | None) -> float | None:
+    if left is None:
+        return right
+    if right is None:
+        return left
+    return max(left, right)
+
+
+@dataclass
+class _EnvelopeAccumulator:
+    is_geography: bool
+    x_min: float | None = None
+    y_min: float | None = None
+    z_min: float | None = None
+    m_min: float | None = None
+    x_max: float | None = None
+    y_max: float | None = None
+    z_max: float | None = None
+    m_max: float | None = None
+    longitudes: list[float] | None = None
+
+    def __post_init__(self) -> None:
+        if self.is_geography:
+            self.longitudes = []
+
+    def add_point(self, x: float, y: float, z: float | None, m: float | None) -> None:
+        if math.isnan(x) or math.isnan(y):
+            return
+
+        if self.is_geography:
+            if self.longitudes is None:
+                self.longitudes = []
+            self.longitudes.append(_normalize_longitude(x))
+        else:
+            self.x_min = x if self.x_min is None else min(self.x_min, x)
+            self.x_max = x if self.x_max is None else max(self.x_max, x)
+
+        self.y_min = y if self.y_min is None else min(self.y_min, y)
+        self.y_max = y if self.y_max is None else max(self.y_max, y)
+
+        if z is not None and not math.isnan(z):
+            self.z_min = z if self.z_min is None else min(self.z_min, z)
+            self.z_max = z if self.z_max is None else max(self.z_max, z)
+
+        if m is not None and not math.isnan(m):
+            self.m_min = m if self.m_min is None else min(self.m_min, m)
+            self.m_max = m if self.m_max is None else max(self.m_max, m)
+
+    def finish(self) -> GeometryEnvelope | None:
+        if self.y_min is None or self.y_max is None:
+            return None
+
+        if self.is_geography:
+            if not self.longitudes:
+                return None
+            x_min, x_max = _minimal_longitude_interval(self.longitudes)
+        else:
+            if self.x_min is None or self.x_max is None:
+                return None
+            x_min, x_max = self.x_min, self.x_max
+
+        return GeometryEnvelope(
+            x_min=x_min,
+            y_min=self.y_min,
+            z_min=self.z_min,
+            m_min=self.m_min,
+            x_max=x_max,
+            y_max=self.y_max,
+            z_max=self.z_max,
+            m_max=self.m_max,
+        )
+
+
+class _WKBReader:
+    def __init__(self, payload: bytes) -> None:
+        self._payload = payload
+        self._offset = 0
+
+    def remaining(self) -> int:
+        return len(self._payload) - self._offset
+
+    def read_byte(self) -> int:
+        self._ensure_size(1)
+        value = self._payload[self._offset]
+        self._offset += 1
+        return value
+
+    def read_uint32(self, little_endian: bool) -> int:
+        return int(self._read_fmt("<I" if little_endian else ">I"))
+
+    def read_double(self, little_endian: bool) -> float:
+        return float(self._read_fmt("<d" if little_endian else ">d"))
+
+    def _read_fmt(self, fmt: str) -> float | int:
+        size = struct.calcsize(fmt)
+        self._ensure_size(size)
+        value = struct.unpack_from(fmt, self._payload, self._offset)[0]
+        self._offset += size
+        return value
+
+    def _ensure_size(self, expected: int) -> None:
+        if self._offset + expected > len(self._payload):
+            raise ValueError("Unexpected end of WKB payload")
+
+
+def _parse_geometry(reader: _WKBReader, accumulator: _EnvelopeAccumulator) -> None:
+    little_endian = _parse_byte_order(reader.read_byte())
+    raw_type = reader.read_uint32(little_endian)
+    geometry_type, has_z, has_m = _parse_geometry_type(raw_type)
+
+    if raw_type & _EWKB_SRID_FLAG:
+        reader.read_uint32(little_endian)
+
+    if geometry_type == _WKB_POINT:
+        _parse_point(reader, accumulator, little_endian, has_z, has_m)
+    elif geometry_type == _WKB_LINESTRING:
+        _parse_points(reader, accumulator, little_endian, has_z, has_m)
+    elif geometry_type == _WKB_POLYGON:
+        _parse_polygon(reader, accumulator, little_endian, has_z, has_m)
+    elif geometry_type in (_WKB_MULTIPOINT, _WKB_MULTILINESTRING, _WKB_MULTIPOLYGON, _WKB_GEOMETRYCOLLECTION):
+        _parse_collection(reader, accumulator, little_endian)
+    else:
+        raise ValueError(f"Unsupported WKB geometry type: {geometry_type}")
+
+
+def _parse_byte_order(order: int) -> bool:
+    if order == 1:
+        return True
+    if order == 0:
+        return False
+    raise ValueError(f"Unsupported WKB byte order marker: {order}")
+
+
+def _parse_geometry_type(raw_type: int) -> tuple[int, bool, bool]:
+    has_z = bool(raw_type & _EWKB_Z_FLAG)
+    has_m = bool(raw_type & _EWKB_M_FLAG)
+    type_code = raw_type & 0x1FFFFFFF
+
+    if type_code >= 3000:
+        has_z = True
+        has_m = True
+        type_code -= 3000
+    elif type_code >= 2000:
+        has_m = True
+        type_code -= 2000
+    elif type_code >= 1000:
+        has_z = True
+        type_code -= 1000
+
+    return type_code, has_z, has_m
+
+
+def _parse_collection(reader: _WKBReader, accumulator: _EnvelopeAccumulator, little_endian: bool) -> None:
+    num_geometries = reader.read_uint32(little_endian)
+    for _ in range(num_geometries):
+        _parse_geometry(reader, accumulator)
+
+
+def _parse_polygon(
+    reader: _WKBReader,
+    accumulator: _EnvelopeAccumulator,
+    little_endian: bool,
+    has_z: 
bool, + has_m: bool, +) -> None: + num_rings = reader.read_uint32(little_endian) + for _ in range(num_rings): + _parse_points(reader, accumulator, little_endian, has_z, has_m) + + +def _parse_points( + reader: _WKBReader, + accumulator: _EnvelopeAccumulator, + little_endian: bool, + has_z: bool, + has_m: bool, +) -> None: + count = reader.read_uint32(little_endian) + for _ in range(count): + x = reader.read_double(little_endian) + y = reader.read_double(little_endian) + if has_z and has_m: + z = reader.read_double(little_endian) + m = reader.read_double(little_endian) + elif has_z: + z = reader.read_double(little_endian) + m = None + elif has_m: + z = None + m = reader.read_double(little_endian) + else: + z = None + m = None + accumulator.add_point(x=x, y=y, z=z, m=m) + + +def _parse_point( + reader: _WKBReader, + accumulator: _EnvelopeAccumulator, + little_endian: bool, + has_z: bool, + has_m: bool, +) -> None: + x = reader.read_double(little_endian) + y = reader.read_double(little_endian) + + if has_z and has_m: + z = reader.read_double(little_endian) + m = reader.read_double(little_endian) + elif has_z: + z = reader.read_double(little_endian) + m = None + elif has_m: + z = None + m = reader.read_double(little_endian) + else: + z = None + m = None + + accumulator.add_point(x=x, y=y, z=z, m=m) + + +def _normalize_longitude(value: float) -> float: + normalized = ((value + 180.0) % 360.0) - 180.0 + if math.isclose(normalized, -180.0) and value > 0: + return 180.0 + return normalized + + +def _to_circle(value: float) -> float: + if math.isclose(value, 180.0): + return 360.0 + return value + 180.0 + + +def _from_circle(value: float) -> float: + if math.isclose(value, 360.0): + return 180.0 + return value - 180.0 + + +def _minimal_longitude_interval(longitudes: list[float]) -> tuple[float, float]: + points = sorted({_to_circle(_normalize_longitude(v)) % 360.0 for v in longitudes}) + if len(points) == 1: + lon = _from_circle(points[0]) + return lon, lon + + max_gap = -1.0 + max_gap_idx = 0 + for idx in range(len(points)): + current = points[idx] + nxt = points[(idx + 1) % len(points)] + (360.0 if idx == len(points) - 1 else 0.0) + gap = nxt - current + if gap > max_gap: + max_gap = gap + max_gap_idx = idx + + start = points[(max_gap_idx + 1) % len(points)] + end = points[max_gap_idx] + return _from_circle(start), _from_circle(end) + + +def _merge_longitude_intervals(left_min: float, left_max: float, right_min: float, right_max: float) -> tuple[float, float]: + segments = _interval_to_segments(left_min, left_max) + _interval_to_segments(right_min, right_max) + merged = _merge_segments(segments) + if not merged: + raise ValueError("Cannot merge empty longitude intervals") + + largest_gap = -1.0 + gap_start = 0.0 + gap_end = 0.0 + for idx in range(len(merged)): + current_end = merged[idx][1] + next_start = merged[(idx + 1) % len(merged)][0] + (360.0 if idx == len(merged) - 1 else 0.0) + gap = next_start - current_end + if gap > largest_gap: + largest_gap = gap + gap_start = current_end + gap_end = next_start + + if largest_gap <= 1e-12: + return -180.0, 180.0 + + start = gap_end % 360.0 + end = gap_start % 360.0 + return _from_circle(start), _from_circle(end) + + +def _interval_to_segments(x_min: float, x_max: float) -> list[tuple[float, float]]: + start = _to_circle(_normalize_longitude(x_min)) + end = _to_circle(_normalize_longitude(x_max)) + + if x_min <= x_max: + return [(start, end)] + return [(start, 360.0), (0.0, end)] + + +def _merge_segments(segments: list[tuple[float, float]]) -> 
list[tuple[float, float]]:
+    if not segments:
+        return []
+
+    ordered = sorted(segments, key=lambda pair: pair[0])
+    merged: list[tuple[float, float]] = [ordered[0]]
+    for start, end in ordered[1:]:
+        previous_start, previous_end = merged[-1]
+        if start <= previous_end:
+            merged[-1] = (previous_start, max(previous_end, end))
+        else:
+            merged.append((start, end))
+    return merged
diff --git a/tests/expressions/test_spatial_predicates.py b/tests/expressions/test_spatial_predicates.py
new file mode 100644
index 0000000000..f489098861
--- /dev/null
+++ b/tests/expressions/test_spatial_predicates.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import struct
+
+import pytest
+
+from pyiceberg.expressions import BooleanExpression, BoundSTContains, Not, Reference, STContains, STIntersects
+from pyiceberg.expressions.visitors import bind, expression_evaluator, translate_column_names
+from pyiceberg.schema import Schema
+from pyiceberg.typedef import Record
+from pyiceberg.types import GeometryType, IntegerType, NestedField
+
+
+def _point_wkb(x: float, y: float) -> bytes:
+    return struct.pack("<BIdd", 1, 1, x, y)
+
+
+def test_st_contains_bind_succeeds_for_geometry_field() -> None:
+    schema = Schema(NestedField(1, "geom", GeometryType(), required=False), schema_id=1)
+    expr = STContains("geom", _point_wkb(1.0, 2.0))
+    bound = bind(schema, expr, case_sensitive=True)
+    assert isinstance(bound, BoundSTContains)
+    assert bound.geometry == _point_wkb(1.0, 2.0)
+
+
+def test_st_contains_bind_fails_for_non_geospatial_field() -> None:
+    schema = Schema(NestedField(1, "id", IntegerType(), required=False), schema_id=1)
+    with pytest.raises(TypeError) as exc_info:
+        bind(schema, STContains("id", _point_wkb(1.0, 2.0)), case_sensitive=True)
+    assert "geometry/geography" in str(exc_info.value)
+
+
+def test_spatial_predicate_json_parsing() -> None:
+    expr = BooleanExpression.model_validate({"type": "st-intersects", "term": "geom", "value": _point_wkb(1.0, 2.0)})
+    assert isinstance(expr, STIntersects)
+    assert expr.geometry == _point_wkb(1.0, 2.0)
+
+
+def test_spatial_predicate_invert_returns_not() -> None:
+    expr = STContains("geom", _point_wkb(1.0, 2.0))
+    assert isinstance(~expr, Not)
+
+
+def test_spatial_expression_evaluator_not_implemented() -> None:
+    schema = Schema(NestedField(1, "geom", GeometryType(), required=False), schema_id=1)
+    evaluator = expression_evaluator(schema, STContains("geom", _point_wkb(1.0, 2.0)), case_sensitive=True)
+    with pytest.raises(NotImplementedError) as exc_info:
+        evaluator(Record(_point_wkb(1.0, 2.0)))
+    assert "st-contains row-level evaluation is not implemented" in str(exc_info.value)
+
+
+def test_translate_column_names_for_spatial_predicate() -> None:
+    original_schema = Schema(NestedField(1, "geom_original", GeometryType(), required=False), schema_id=1)
+    file_schema = Schema(NestedField(1, 
"geom_file", GeometryType(), required=False), schema_id=1) + + bound_expr = bind(original_schema, STContains("geom_original", _point_wkb(1.0, 2.0)), case_sensitive=True) + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + assert isinstance(translated_expr, STContains) + assert translated_expr.term == Reference("geom_file") + assert translated_expr.geometry == _point_wkb(1.0, 2.0) + + +def test_translate_column_names_for_spatial_predicate_missing_column_raises() -> None: + original_schema = Schema(NestedField(1, "geom", GeometryType(), required=False), schema_id=1) + file_schema = Schema(NestedField(2, "other_col", IntegerType(), required=False), schema_id=1) + bound_expr = bind(original_schema, STContains("geom", _point_wkb(1.0, 2.0)), case_sensitive=True) + + with pytest.raises(NotImplementedError) as exc_info: + translate_column_names(bound_expr, file_schema, case_sensitive=True) + assert "Spatial predicate translation is not supported when source columns are missing" in str(exc_info.value) diff --git a/tests/integration/test_geospatial_integration.py b/tests/integration/test_geospatial_integration.py new file mode 100644 index 0000000000..c03f722d64 --- /dev/null +++ b/tests/integration/test_geospatial_integration.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import struct
+from uuid import uuid4
+
+import pyarrow as pa
+import pytest
+from pytest_lazy_fixtures import lf
+
+from pyiceberg.catalog import Catalog, load_catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.io.pyarrow import schema_to_pyarrow
+from pyiceberg.schema import Schema
+from pyiceberg.table import TableProperties
+from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION
+from pyiceberg.types import GeographyType, GeometryType, IntegerType, NestedField
+
+
+@pytest.fixture()
+def rest_catalog() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "rest",
+            "uri": "http://localhost:8181",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
+def _drop_if_exists(catalog: Catalog, identifier: str) -> None:
+    try:
+        catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass
+
+
+def _as_bytes(value: object) -> bytes:
+    if isinstance(value, bytes):
+        return value
+    if isinstance(value, bytearray):
+        return bytes(value)
+    if isinstance(value, memoryview):
+        return value.tobytes()
+    if hasattr(value, "to_wkb"):
+        return bytes(value.to_wkb())
+    raise TypeError(f"Unsupported value type: {type(value)}")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [lf("session_catalog"), lf("rest_catalog")])
+def test_write_read_roundtrip_geospatial(catalog: Catalog) -> None:
+    identifier = f"default.test_geospatial_roundtrip_{uuid4().hex[:8]}"
+    _drop_if_exists(catalog, identifier)
+
+    schema = Schema(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(2, "geom", GeometryType(), required=False),
+        NestedField(3, "geog", GeographyType(), required=False),
+    )
+    table = catalog.create_table(
+        identifier=identifier,
+        schema=schema,
+        properties={TableProperties.FORMAT_VERSION: "3"},
+    )
+
+    geom = struct.pack("<BIdd", 1, 1, 1.0, 2.0)
+    geog = struct.pack("<BIdd", 1, 1, 170.0, 10.0)
+    arrow_table = pa.Table.from_pylist(
+        [{"id": 1, "geom": geom, "geog": geog}],
+        schema=schema_to_pyarrow(table.schema()),
+    )
+    table.append(arrow_table)
+    assert table.format_version == SUPPORTED_TABLE_FORMAT_VERSION
+
+    result = table.scan().to_arrow()
+    assert result.num_rows == 1
+    assert _as_bytes(result["geom"][0].as_py()) == geom
+    assert _as_bytes(result["geog"][0].as_py()) == geog
+
+    catalog.drop_table(identifier)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [lf("session_catalog"), lf("rest_catalog")])
+def test_schema_evolution_add_geospatial_column(catalog: Catalog) -> None:
+    identifier = f"default.test_geospatial_evolution_{uuid4().hex[:8]}"
+    _drop_if_exists(catalog, identifier)
+
+    schema = Schema(NestedField(1, "id", IntegerType(), required=True))
+    table = catalog.create_table(
+        identifier=identifier,
+        schema=schema,
+        properties={TableProperties.FORMAT_VERSION: "3"},
+    )
+    table.update_schema().add_column("geom", GeometryType()).commit()
+
+    reloaded = catalog.load_table(identifier)
+    assert isinstance(reloaded.schema().find_field("geom").field_type, GeometryType)
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 2170741bdd..d499344ea4 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -80,6 +80,7 @@
     data_file_statistics_from_parquet_metadata,
     expression_to_pyarrow,
     parquet_path_to_id_mapping,
+    pyarrow_to_schema,
     schema_to_pyarrow,
     write_file,
 )
@@ -659,6 +660,41 @@ def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
     sys.modules.update(saved_modules)
+
+
+def test_geospatial_type_to_pyarrow_without_geoarrow_logs_warning_once(caplog: pytest.LogCaptureFixture) -> None:
+    """Test that a missing geoarrow dependency logs a single warning and falls back to binary."""
+    import sys
+
+    import pyiceberg.io.pyarrow as pyarrow_io
+
+    pyarrow_io._warn_geoarrow_unavailable.cache_clear()
+    caplog.set_level(logging.WARNING, logger="pyiceberg.io.pyarrow")
+
+    saved_modules = {}
+    for mod_name in list(sys.modules.keys()):
+        if mod_name.startswith("geoarrow"):
+            saved_modules[mod_name] = sys.modules.pop(mod_name)
+
+    import builtins
+
+    original_import = builtins.__import__
+
+    def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
+        if 
name.startswith("geoarrow"): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + try: + builtins.__import__ = mock_import + assert visit(GeometryType(), _ConvertToArrowSchema()) == pa.large_binary() + assert visit(GeographyType(), _ConvertToArrowSchema()) == pa.large_binary() + finally: + builtins.__import__ = original_import + sys.modules.update(saved_modules) + + warning_records = [r for r in caplog.records if "geoarrow-pyarrow is not installed" in r.getMessage()] + assert len(warning_records) == 1 + + def test_geometry_type_to_pyarrow_with_geoarrow() -> None: """Test geometry type uses geoarrow WKB extension type when available.""" pytest.importorskip("geoarrow.pyarrow") @@ -701,6 +737,44 @@ def test_geography_type_to_pyarrow_with_geoarrow() -> None: assert result_planar == expected_planar +def test_pyarrow_to_schema_with_geoarrow_wkb_extensions() -> None: + pytest.importorskip("geoarrow.pyarrow") + + iceberg_schema = Schema( + NestedField(1, "geom", GeometryType(), required=False), + NestedField(2, "geog", GeographyType(), required=False), + schema_id=1, + ) + arrow_schema = schema_to_pyarrow(iceberg_schema) + converted = pyarrow_to_schema(arrow_schema, format_version=3) + + assert converted.find_field("geom").field_type == GeometryType() + assert converted.find_field("geog").field_type == GeographyType() + + +def test_check_pyarrow_schema_compatible_allows_planar_geography_geometry_equivalence() -> None: + pytest.importorskip("geoarrow.pyarrow") + + requested_schema = Schema(NestedField(1, "shape", GeographyType("OGC:CRS84", "planar"), required=False), schema_id=1) + provided_arrow_schema = schema_to_pyarrow( + Schema(NestedField(1, "shape", GeometryType("OGC:CRS84"), required=False), schema_id=1) + ) + + _check_pyarrow_schema_compatible(requested_schema, provided_arrow_schema, format_version=3) + + +def test_check_pyarrow_schema_compatible_rejects_spherical_geography_geometry_equivalence() -> None: + pytest.importorskip("geoarrow.pyarrow") + + requested_schema = Schema(NestedField(1, "shape", GeographyType("OGC:CRS84", "spherical"), required=False), schema_id=1) + provided_arrow_schema = schema_to_pyarrow( + Schema(NestedField(1, "shape", GeometryType("OGC:CRS84"), required=False), schema_id=1) + ) + + with pytest.raises(ValueError, match="Mismatch in fields"): + _check_pyarrow_schema_compatible(requested_schema, provided_arrow_schema, format_version=3) + + def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None: expected = pa.struct( [ diff --git a/tests/io/test_pyarrow_stats.py b/tests/io/test_pyarrow_stats.py index 0e628829eb..2a0cea85cb 100644 --- a/tests/io/test_pyarrow_stats.py +++ b/tests/io/test_pyarrow_stats.py @@ -17,6 +17,7 @@ # pylint: disable=protected-access,unused-argument,redefined-outer-name import math +import struct import tempfile import uuid from dataclasses import asdict, dataclass @@ -44,16 +45,23 @@ STRUCT_INT64, ) from pyiceberg.io.pyarrow import ( + DataFileStatistics, + GeospatialStatsAggregator, MetricModeTypes, MetricsMode, + PyArrowFileIO, PyArrowStatisticsCollector, compute_statistics_plan, data_file_statistics_from_parquet_metadata, + geospatial_column_aggregates_from_arrow_table, + geospatial_column_aggregates_from_parquet_file, match_metrics_mode, + parquet_file_to_data_file, parquet_path_to_id_mapping, schema_to_pyarrow, ) from pyiceberg.manifest import DataFile +from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, pre_order_visit from 
pyiceberg.table.metadata import (
     TableMetadata,
@@ -61,13 +69,18 @@
     TableMetadataV1,
     TableMetadataV2,
 )
+from pyiceberg.transforms import IdentityTransform
 from pyiceberg.types import (
     BooleanType,
     FloatType,
+    GeographyType,
+    GeometryType,
     IntegerType,
+    NestedField,
     StringType,
 )
 from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros
+from pyiceberg.utils.geospatial import deserialize_geospatial_bound
 
 
 @dataclass(frozen=True)
@@ -175,6 +188,48 @@ def construct_test_table(
     return metadata_collector[0], table_metadata
 
 
+def construct_geospatial_test_table() -> tuple[pq.FileMetaData, TableMetadataV1 | TableMetadataV2, pa.Table]:
+    table_metadata = TableMetadataUtil.parse_obj(
+        {
+            "format-version": 3,
+            "location": "s3://bucket/test/location",
+            "last-column-id": 2,
+            "current-schema-id": 0,
+            "schemas": [
+                {
+                    "type": "struct",
+                    "schema-id": 0,
+                    "fields": [
+                        {"id": 1, "name": "geom", "required": False, "type": "geometry"},
+                        {"id": 2, "name": "geog", "required": False, "type": "geography"},
+                    ],
+                }
+            ],
+            "default-spec-id": 0,
+            "partition-specs": [{"spec-id": 0, "fields": []}],
+            "properties": {},
+        }
+    )
+    schema = Schema(
+        NestedField(1, "geom", GeometryType(), required=False),
+        NestedField(2, "geog", GeographyType(), required=False),
+    )
+    arrow_schema = schema_to_pyarrow(schema)
+
+    # Little-endian WKB: byte-order flag, type (2 = LineString), point count, coordinates
+    # LINESTRING(1 2, 3 4)
+    geom = struct.pack("<bIIdddd", 1, 2, 2, 1.0, 2.0, 3.0, 4.0)
+    # LINESTRING(170 10, -170 20), crossing the antimeridian
+    geog = struct.pack("<bIIdddd", 1, 2, 2, 170.0, 10.0, -170.0, 20.0)
+    arrow_table = pa.Table.from_pylist([{"geom": geom, "geog": geog}], schema=arrow_schema)
+
+    metadata_collector: list[pq.FileMetaData] = []
+    with pa.BufferOutputStream() as f:
+        with pq.ParquetWriter(f, arrow_schema, metadata_collector=metadata_collector) as writer:
+            writer.write_table(arrow_table)
+
+    return metadata_collector[0], table_metadata, arrow_table
+
+
 def get_current_schema(table_metadata: TableMetadata) -> Schema:
@@ -282,6 +337,202 @@ def test_bounds() -> None:
     assert datafile.upper_bounds[2] == STRUCT_FLOAT.pack(100)
 
 
+def test_geospatial_bounds_use_bound_serialization() -> None:
+    metadata, table_metadata, arrow_table = construct_geospatial_test_table()
+    schema = get_current_schema(table_metadata)
+    stats_columns = compute_statistics_plan(schema, table_metadata.properties)
+    statistics = data_file_statistics_from_parquet_metadata(
+        parquet_metadata=metadata,
+        stats_columns=stats_columns,
+        parquet_column_mapping=parquet_path_to_id_mapping(schema),
+    )
+    statistics.column_aggregates.update(geospatial_column_aggregates_from_arrow_table(arrow_table, stats_columns))
+    datafile = DataFile.from_args(**statistics.to_serialized_dict())
+
+    geom_min = deserialize_geospatial_bound(datafile.lower_bounds[1])
+    geom_max = deserialize_geospatial_bound(datafile.upper_bounds[1])
+    assert geom_min.x == 1.0
+    assert geom_min.y == 2.0
+    assert geom_max.x == 3.0
+    assert geom_max.y == 4.0
+
+    geog_min = deserialize_geospatial_bound(datafile.lower_bounds[2])
+    geog_max = deserialize_geospatial_bound(datafile.upper_bounds[2])
+    assert geog_min.x > geog_max.x
+    assert geog_min.x == 170.0
+    assert geog_max.x == -170.0
+    assert geog_min.y == 10.0
+    assert geog_max.y == 20.0
+
+
+def test_geospatial_column_aggregates_from_parquet_file() -> None:
+    schema = Schema(
+        NestedField(1, "geom", GeometryType(), required=False),
+        NestedField(2, "geog", GeographyType(), required=False),
+    )
+    stats_columns = compute_statistics_plan(schema, {})
+    arrow_schema = schema_to_pyarrow(schema)
+    # LINESTRING(1 2, 3 4)
+    geom = struct.pack("<bIIdddd", 1, 2, 2, 1.0, 2.0, 3.0, 4.0)
+    # LINESTRING(170 10, -170 20)
+    geog = struct.pack("<bIIdddd", 1, 2, 2, 170.0, 10.0, -170.0, 20.0)
+    arrow_table = pa.Table.from_pylist([{"geom": geom, "geog": geog}], schema=arrow_schema)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        file_path = f"{tmpdir}/geospatial.parquet"
+        pq.write_table(arrow_table, file_path)
+        input_file = PyArrowFileIO().new_input(file_path)
+        aggregates = geospatial_column_aggregates_from_parquet_file(input_file, stats_columns)
+
+    geom_min = deserialize_geospatial_bound(aggregates[1].min_as_bytes())
+    geom_max = deserialize_geospatial_bound(aggregates[1].max_as_bytes())
+    assert geom_min.x == 1.0
+    assert geom_min.y == 2.0
+    assert geom_max.x == 3.0
+    assert geom_max.y == 4.0
+
+    geog_min = deserialize_geospatial_bound(aggregates[2].min_as_bytes())
+    geog_max = deserialize_geospatial_bound(aggregates[2].max_as_bytes())
+    assert geog_min.x > geog_max.x
+    assert geog_min.x == 170.0
+    assert geog_max.x == -170.0
+    assert geog_min.y == 10.0
+    assert geog_max.y == 20.0
+
+
+def test_parquet_file_to_data_file_with_geospatial_schema() -> None:
+    table_metadata = TableMetadataUtil.parse_obj(
+        {
+            "format-version": 3,
+            "location": "s3://bucket/test/location",
+            "last-column-id": 2,
+            "current-schema-id": 0,
+            "schemas": [
+                {
+                    "type": "struct",
+                    "schema-id": 0,
+                    "fields": [
+                        {"id": 1, "name": "geom", "required": False, "type": "geometry"},
+                        {"id": 2, "name": "geog", "required": False, 
"type": "geography"}, + ], + } + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "properties": {}, + } + ) + schema = Schema( + NestedField(1, "geom", GeometryType(), required=False), + NestedField(2, "geog", GeographyType(), required=False), + ) + arrow_schema = schema_to_pyarrow(schema) + geom = struct.pack(" geog_max.x + assert geog_min.x == 170.0 + assert geog_max.x == -170.0 + assert geog_min.y == 10.0 + assert geog_max.y == 20.0 + + +def test_parquet_file_to_data_file_with_planar_geography_schema() -> None: + table_metadata = TableMetadataUtil.parse_obj( + { + "format-version": 3, + "location": "s3://bucket/test/location", + "last-column-id": 1, + "current-schema-id": 0, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "fields": [ + { + "id": 1, + "name": "geog", + "required": False, + "type": "geography('OGC:CRS84', 'planar')", + } + ], + } + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "properties": {}, + } + ) + schema = Schema(NestedField(1, "geog", GeographyType("OGC:CRS84", "planar"), required=False)) + arrow_schema = schema_to_pyarrow(schema) + geog = struct.pack(" None: + schema = Schema(NestedField(1, "geom", GeometryType(), required=False)) + partition_spec = PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="geom"), + spec_id=0, + ) + geospatial_agg = GeospatialStatsAggregator(GeometryType()) + geospatial_agg.update_from_wkb(struct.pack(" None: assert match_metrics_mode("none") == MetricsMode(MetricModeTypes.NONE) assert match_metrics_mode("nOnE") == MetricsMode(MetricModeTypes.NONE) diff --git a/tests/utils/test_geospatial.py b/tests/utils/test_geospatial.py new file mode 100644 index 0000000000..aa834b4216 --- /dev/null +++ b/tests/utils/test_geospatial.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import math
+import struct
+
+from pyiceberg.utils.geospatial import (
+    GeospatialBound,
+    deserialize_geospatial_bound,
+    extract_envelope_from_wkb,
+    merge_envelopes,
+    serialize_geospatial_bound,
+)
+
+
+def test_geospatial_bound_serde_xy() -> None:
+    raw = serialize_geospatial_bound(GeospatialBound(x=10.0, y=20.0))
+    assert len(raw) == 16
+    bound = deserialize_geospatial_bound(raw)
+    assert bound.x == 10.0
+    assert bound.y == 20.0
+    assert bound.z is None
+    assert bound.m is None
+
+
+def test_geospatial_bound_serde_xyz() -> None:
+    raw = serialize_geospatial_bound(GeospatialBound(x=10.0, y=20.0, z=30.0))
+    assert len(raw) == 24
+    bound = deserialize_geospatial_bound(raw)
+    assert bound.x == 10.0
+    assert bound.y == 20.0
+    assert bound.z == 30.0
+    assert bound.m is None
+
+
+def test_geospatial_bound_serde_xym() -> None:
+    raw = serialize_geospatial_bound(GeospatialBound(x=10.0, y=20.0, m=40.0))
+    assert len(raw) == 32
+    # An x/y/m bound is written in the 32-byte x/y/z/m layout with z = NaN
+    x, y, z, m = struct.unpack("<4d", raw)
+    assert x == 10.0
+    assert y == 20.0
+    assert math.isnan(z)
+    assert m == 40.0
+
+
+def test_geospatial_bound_serde_xyzm() -> None:
+    raw = serialize_geospatial_bound(GeospatialBound(x=10.0, y=20.0, z=30.0, m=40.0))
+    assert len(raw) == 32
+    bound = deserialize_geospatial_bound(raw)
+    assert bound.x == 10.0
+    assert bound.y == 20.0
+    assert bound.z == 30.0
+    assert bound.m == 40.0
+
+
+def test_extract_envelope_geometry() -> None:
+    # LINESTRING(170 0, -170 1)
+    wkb = struct.pack("<bIIdddd", 1, 2, 2, 170.0, 0.0, -170.0, 1.0)
+    envelope = extract_envelope_from_wkb(wkb)
+    # Planar geometry keeps the plain min/max envelope: no antimeridian wrap
+    assert envelope.x_min == -170.0
+    assert envelope.x_max == 170.0
+    assert envelope.y_min == 0.0
+    assert envelope.y_max == 1.0
+
+
+def test_extract_envelope_geography_wraps_antimeridian() -> None:
+    # LINESTRING(170 0, -170 1)
+    wkb = struct.pack("<bIIdddd", 1, 2, 2, 170.0, 0.0, -170.0, 1.0)
+    envelope = extract_envelope_from_wkb(wkb, geography=True)
+    # A geography envelope crossing the antimeridian is encoded as x_min > x_max
+    assert envelope.x_min > envelope.x_max
+    assert envelope.x_min == 170.0
+    assert envelope.x_max == -170.0
+    assert envelope.y_min == 0.0
+    assert envelope.y_max == 1.0
+
+
+def test_extract_envelope_xyzm_linestring() -> None:
+    # LINESTRING ZM (0 1 2 3, 4 5 6 7); 3002 is the ISO WKB code for LineString ZM
+    wkb = struct.pack("<bII8d", 1, 3002, 2, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0)
+    envelope = extract_envelope_from_wkb(wkb)
+    assert envelope.x_min == 0.0
+    assert envelope.x_max == 4.0
+    assert envelope.y_min == 1.0
+    assert envelope.y_max == 5.0
+    assert envelope.z_min == 2.0
+    assert envelope.z_max == 6.0
+    assert envelope.m_min == 3.0
+    assert envelope.m_max == 7.0
+
+
+def test_merge_envelopes_wraps_antimeridian() -> None:
+    # POINT(170 0) merged with POINT(-120 3): the shorter longitude span
+    # crosses the antimeridian
+    left = extract_envelope_from_wkb(struct.pack("<bIdd", 1, 1, 170.0, 0.0), geography=True)
+    right = extract_envelope_from_wkb(struct.pack("<bIdd", 1, 1, -120.0, 3.0), geography=True)
+    merged = merge_envelopes(left, right)
+    assert merged.x_min > merged.x_max
+    assert merged.x_min == 170.0
+    assert merged.x_max == -120.0
+    assert merged.y_min == 0.0
+    assert merged.y_max == 3.0
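For reviewers who want to sanity-check the serde tests above by hand: the bound layout they assert is concatenated little-endian doubles, `x,y` (16 bytes), `x,y,z` (24 bytes), or `x,y,z,m` (32 bytes, with `z` stored as NaN when only `m` is set). The sketch below is illustrative only and assumes exactly those layouts; `encode_bound` is a hypothetical helper, not an API from this PR.

```python
# Minimal sketch of the 16/24/32-byte bound layouts asserted in
# tests/utils/test_geospatial.py. Hypothetical helper, not PR code.
import math
import struct


def encode_bound(x: float, y: float, z: float | None = None, m: float | None = None) -> bytes:
    if m is not None:
        # x/y/z/m form: 32 bytes; NaN marks an absent z coordinate
        return struct.pack("<4d", x, y, math.nan if z is None else z, m)
    if z is not None:
        return struct.pack("<3d", x, y, z)  # x/y/z form: 24 bytes
    return struct.pack("<2d", x, y)  # x/y form: 16 bytes


assert len(encode_bound(10.0, 20.0)) == 16
assert len(encode_bound(10.0, 20.0, z=30.0)) == 24
assert len(encode_bound(10.0, 20.0, m=40.0)) == 32
assert math.isnan(struct.unpack("<4d", encode_bound(10.0, 20.0, m=40.0))[2])
```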