From e00f36c4496d5ce5fdf80d9294b511c628d309eb Mon Sep 17 00:00:00 2001
From: Mario Taddeucci
Date: Tue, 17 Feb 2026 15:34:25 -0300
Subject: [PATCH] Add mapInPandas and mapInArrow methods to DataFrame class with tests

---
 duckdb/experimental/spark/_typing.py            |   9 +
 duckdb/experimental/spark/sql/dataframe.py      | 199 ++++++++++++++++++
 .../fast/spark/test_spark_dataframe_map_in.py   |  88 ++++++++
 3 files changed, 296 insertions(+)
 create mode 100644 tests/fast/spark/test_spark_dataframe_map_in.py

diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py
index 1ed78ea8..a361bcb8 100644
--- a/duckdb/experimental/spark/_typing.py
+++ b/duckdb/experimental/spark/_typing.py
@@ -19,7 +19,9 @@
 from collections.abc import Iterable, Sized
 from typing import Callable, TypeVar, Union
 
+import pyarrow
 from numpy import float32, float64, int32, int64, ndarray
+from pandas import DataFrame as PandasDataFrame
 from typing_extensions import Literal, Protocol, Self
 
 F = TypeVar("F", bound=Callable)
@@ -30,6 +32,13 @@
 NonUDFType = Literal[0]
 
+DataFrameLike = PandasDataFrame
+
+PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]
+
+ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
+
+
 class SupportsIAdd(Protocol):
     def __iadd__(self, other: "SupportsIAdd") -> Self: ...
 
diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py
index e7519e81..2bfff8fb 100644
--- a/duckdb/experimental/spark/sql/dataframe.py
+++ b/duckdb/experimental/spark/sql/dataframe.py
@@ -1,4 +1,5 @@
 import uuid  # noqa: D100
+from collections.abc import Iterable
 from functools import reduce
 from keyword import iskeyword
 from typing import (
@@ -13,6 +14,7 @@
 
 import duckdb
 from duckdb import ColumnExpression, Expression, StarExpression
+from duckdb.experimental.spark.exception import ContributionsAcceptedError
 
 from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
 from .column import Column
@@ -24,6 +26,7 @@
     import pyarrow as pa
     from pandas.core.frame import DataFrame as PandasDataFrame
 
+    from .._typing import ArrowMapIterFunction, PandasMapIterFunction
     from ._typing import ColumnOrName
     from .group import GroupedData
     from .session import SparkSession
@@ -1419,5 +1422,201 @@ def cache(self) -> "DataFrame":
         cached_relation = self.relation.execute()
         return DataFrame(cached_relation, self.session)
 
+    def mapInArrow(
+        self,
+        func: "ArrowMapIterFunction",
+        schema: Union[StructType, str],
+        barrier: bool = False,
+        profile: Optional[Any] = None,  # noqa: ANN401
+    ) -> "DataFrame":
+        """Maps an iterator of batches in the current :class:`DataFrame` using a Python native
+        function that is performed on `pyarrow.RecordBatch`\\s both as input and output,
+        and returns the result as a :class:`DataFrame`.
+
+        This method applies the specified Python function to an iterator of
+        `pyarrow.RecordBatch`\\s, each representing a batch of rows from the original DataFrame.
+        The returned iterator of `pyarrow.RecordBatch`\\s is combined into a :class:`DataFrame`.
+        The size of the function's input and output can be different. Each `pyarrow.RecordBatch`
+        size can be controlled by `spark.sql.execution.arrow.maxRecordsPerBatch`.
+
+        .. versionadded:: 3.3.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function that takes an iterator of `pyarrow.RecordBatch`\\s, and
+            outputs an iterator of `pyarrow.RecordBatch`\\s.
+        schema : :class:`pyspark.sql.types.DataType` or str
+            the return type of the `func` in PySpark. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        barrier : bool, optional, default False
+            Use barrier mode execution, ensuring that all Python workers in the stage will be
+            launched concurrently.
+
+            .. versionadded:: 3.5.0
+
+        profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
+            to be used for mapInArrow.
+
+            .. versionadded:: 4.0.0
+
+        Examples:
+        --------
+        >>> import pyarrow as pa
+        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+        >>> def filter_func(iterator):
+        ...     for batch in iterator:
+        ...         pdf = batch.to_pandas()
+        ...         yield pa.RecordBatch.from_pandas(pdf[pdf.id == 1])
+        >>> df.mapInArrow(filter_func, df.schema).show()
+        +---+---+
+        | id|age|
+        +---+---+
+        |  1| 21|
+        +---+---+
+
+        Set ``barrier`` to ``True`` to force the ``mapInArrow`` stage to run in
+        barrier mode; this ensures that all Python workers in the stage will be
+        launched concurrently.
+
+        >>> df.mapInArrow(filter_func, df.schema, barrier=True).collect()
+        [Row(id=1, age=21)]
+
+        See Also:
+        --------
+        pyspark.sql.functions.pandas_udf
+        DataFrame.mapInPandas
+        """  # noqa: D205, D301
+        if isinstance(schema, str):
+            msg = "DDL-formatted type string is not supported yet for the 'schema' parameter."
+            raise ContributionsAcceptedError(msg)
+
+        if profile is not None:
+            msg = "ResourceProfile is not supported yet for the 'profile' parameter."
+            raise ContributionsAcceptedError(msg)
+
+        del barrier  # Ignored because DuckDB runs on a single node and has no barrier execution mode.
+
+        import pyarrow as pa
+        from pyarrow.dataset import dataset
+
+        arrow_schema = self.session.createDataFrame([], schema=schema).toArrow().schema
+        record_batches = self.relation.fetch_record_batch()
+        batch_generator = func(record_batches)
+        reader = pa.RecordBatchReader.from_batches(arrow_schema, batch_generator)
+        ds = dataset(reader)  # noqa: F841
+        df = DataFrame(self.session.conn.sql("SELECT * FROM ds"), self.session)
+        return df
+
+    def mapInPandas(
+        self,
+        func: "PandasMapIterFunction",
+        schema: Union[StructType, str],
+        barrier: bool = False,
+        profile: Optional[Any] = None,  # noqa: ANN401
+    ) -> "DataFrame":
+        """Maps an iterator of batches in the current :class:`DataFrame` using a Python native
+        function that is performed on pandas DataFrames both as input and output,
+        and returns the result as a :class:`DataFrame`.
+
+        This method applies the specified Python function to an iterator of
+        `pandas.DataFrame`\\s, each representing a batch of rows from the original DataFrame.
+        The returned iterator of `pandas.DataFrame`\\s is combined into a :class:`DataFrame`.
+        The size of the function's input and output can be different. Each `pandas.DataFrame`
+        size can be controlled by `spark.sql.execution.arrow.maxRecordsPerBatch`.
+
+        .. versionadded:: 3.0.0
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        func : function
+            a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
+            outputs an iterator of `pandas.DataFrame`\\s.
+        schema : :class:`pyspark.sql.types.DataType` or str
+            the return type of the `func` in PySpark. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        barrier : bool, optional, default False
+            Use barrier mode execution, ensuring that all Python workers in the stage will be
+            launched concurrently.
+
+            .. versionadded:: 3.5.0
+
+        profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
+            to be used for mapInPandas.
+
+            .. versionadded:: 4.0.0
+
+
+        Examples:
+        --------
+        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+
+        Filter rows with id equal to 1:
+
+        >>> def filter_func(iterator):
+        ...     for pdf in iterator:
+        ...         yield pdf[pdf.id == 1]
+        >>> df.mapInPandas(filter_func, df.schema).show()
+        +---+---+
+        | id|age|
+        +---+---+
+        |  1| 21|
+        +---+---+
+
+        Compute the mean age for each id:
+
+        >>> def mean_age(iterator):
+        ...     for pdf in iterator:
+        ...         yield pdf.groupby("id").mean().reset_index()
+        >>> df.mapInPandas(mean_age, "id: bigint, age: double").show()
+        +---+----+
+        | id| age|
+        +---+----+
+        |  1|21.0|
+        |  2|30.0|
+        +---+----+
+
+        Add a new column with double the age:
+
+        >>> def double_age(iterator):
+        ...     for pdf in iterator:
+        ...         pdf["double_age"] = pdf["age"] * 2
+        ...         yield pdf
+        >>> df.mapInPandas(double_age, "id: bigint, age: bigint, double_age: bigint").show()
+        +---+---+----------+
+        | id|age|double_age|
+        +---+---+----------+
+        |  1| 21|        42|
+        |  2| 30|        60|
+        +---+---+----------+
+
+        Set ``barrier`` to ``True`` to force the ``mapInPandas`` stage to run in
+        barrier mode; this ensures that all Python workers in the stage will be
+        launched concurrently.
+
+        >>> df.mapInPandas(filter_func, df.schema, barrier=True).collect()
+        [Row(id=1, age=21)]
+
+        See Also:
+        --------
+        pyspark.sql.functions.pandas_udf
+        DataFrame.mapInArrow
+        """  # noqa: D205, D301
+        import pyarrow as pa
+
+        def _build_arrow_func(pandas_func: "PandasMapIterFunction") -> "ArrowMapIterFunction":
+            def _map_func(record_batches: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
+                pandas_iterator = (batch.to_pandas() for batch in record_batches)
+                pandas_result_gen = pandas_func(pandas_iterator)
+                batch_iterator = (pa.RecordBatch.from_pandas(pdf) for pdf in pandas_result_gen)
+                yield from batch_iterator
+
+            return _map_func
+
+        return self.mapInArrow(_build_arrow_func(func), schema, barrier, profile)
+
 
 __all__ = ["DataFrame"]
diff --git a/tests/fast/spark/test_spark_dataframe_map_in.py b/tests/fast/spark/test_spark_dataframe_map_in.py
new file mode 100644
index 00000000..92bb4fef
--- /dev/null
+++ b/tests/fast/spark/test_spark_dataframe_map_in.py
@@ -0,0 +1,88 @@
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pytest
+
+_ = pytest.importorskip("duckdb.experimental.spark")
+
+from spark_namespace.sql import functions as F
+from spark_namespace.sql.types import Row
+
+
+class TestDataFrameMapInMethods:
+    data = ((56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben"))
+
+    def test_map_in_pandas(self, spark):
+        def filter_func(iterator):
+            for pdf in iterator:
+                yield pdf[pdf.age == 3]
+
+        df = spark.createDataFrame(self.data, ["age", "name"])
+        df = df.mapInPandas(filter_func, schema=df.schema)
+        df = df.sort(["age", "name"])
+
+        expected = [
+            Row(age=3, name="Anna"),
+            Row(age=3, name="Dave"),
+        ]
+
+        assert df.collect() == expected
+
+    def test_map_in_pandas_empty_result(self, spark):
+        def filter_func(iterator):
+            for pdf in iterator:
+                yield pdf[pdf.age > 100]
+
+        df = spark.createDataFrame(self.data, ["age", "name"])
+        df = df.mapInPandas(filter_func, schema=df.schema)
+
+        expected = []
+
+        assert df.collect() == expected
+        assert df.schema == spark.createDataFrame([], schema=df.schema).schema
+
+    def test_map_in_pandas_large_dataset_ensure_no_data_loss(self, spark):
+        def identity_func(iterator):
+            for pdf in iterator:
+                pdf = pdf[pdf.id >= 0]  # Apply a filter to ensure the DataFrame is evaluated
+                yield pdf
+
+        n = 10_000_000
+
+        pandas_df = pd.DataFrame(
+            {
+                "id": np.arange(n, dtype=np.int64),
+                "value_float": np.random.rand(n).astype(np.float32),
+                "value_int": np.random.randint(0, 1000, size=n, dtype=np.int32),
+                "category": np.random.randint(0, 10, size=n, dtype=np.int8),
+            }
+        )
+
+        df = spark.createDataFrame(pandas_df)
+        df = df.mapInPandas(identity_func, schema=df.schema)
+        # Apply filters so the whole DataFrame is evaluated
+        df = df.filter(F.col("id") <= n).filter(F.col("id") >= 0).filter(F.col("category") >= 0)
+
+        generated_pandas_df = df.toPandas()
+        total_records = df.count()
+
+        assert total_records == n
+        assert pandas_df["id"].equals(generated_pandas_df["id"])
+
+    def test_map_in_arrow(self, spark):
+        def filter_func(iterator):
+            for batch in iterator:
+                df = batch.to_pandas()
+                df = df[df.age == 3]
+                yield pa.RecordBatch.from_pandas(df)
+
+        df = spark.createDataFrame(self.data, ["age", "name"])
+        df = df.mapInArrow(filter_func, schema=df.schema)
+        df = df.sort(["age", "name"])
+
+        expected = [
+            Row(age=3, name="Anna"),
+            Row(age=3, name="Dave"),
+        ]
+
+        assert df.collect() == expected
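
A minimal end-to-end usage sketch of the two new methods through DuckDB's experimental Spark API (not part of the patch itself). It assumes the `SparkSession` entry point from `duckdb.experimental.spark.sql` and session creation via `SparkSession.builder.getOrCreate()`; the helper functions and sample data are illustrative only.

    import pyarrow as pa

    from duckdb.experimental.spark.sql import SparkSession

    # Assumption: a local session can be created this way; in the test suite the
    # equivalent object is provided by the `spark` fixture.
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 21), (2, 30), (3, 45)], ["id", "age"])

    # mapInPandas: each batch arrives as a pandas.DataFrame; yield transformed batches.
    # The patch only accepts StructType schemas (DDL strings raise
    # ContributionsAcceptedError), so the input schema is reused here.
    def adults_only_pandas(batches):
        for pdf in batches:
            yield pdf[pdf.age >= 30]

    df.mapInPandas(adults_only_pandas, df.schema).show()

    # mapInArrow: the same idea, but batches arrive as pyarrow.RecordBatch objects.
    # preserve_index=False keeps the pandas index out of the yielded batch so its
    # schema matches the declared output schema.
    def adults_only_arrow(batches):
        for batch in batches:
            pdf = batch.to_pandas()
            yield pa.RecordBatch.from_pandas(pdf[pdf.age >= 30], preserve_index=False)

    print(df.mapInArrow(adults_only_arrow, df.schema).collect())

Both calls route through the same Arrow path: mapInPandas simply wraps the pandas function so each `pyarrow.RecordBatch` is converted to a pandas DataFrame and back before the batches are registered as a pyarrow dataset and scanned by DuckDB.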