9 changes: 9 additions & 0 deletions duckdb/experimental/spark/_typing.py
@@ -19,7 +19,9 @@
from collections.abc import Iterable, Sized
from typing import Callable, TypeVar, Union

import pyarrow
from numpy import float32, float64, int32, int64, ndarray
from pandas import DataFrame as PandasDataFrame
from typing_extensions import Literal, Protocol, Self

F = TypeVar("F", bound=Callable)
@@ -30,6 +32,13 @@
NonUDFType = Literal[0]


DataFrameLike = PandasDataFrame

PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]

ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]


class SupportsIAdd(Protocol):
def __iadd__(self, other: "SupportsIAdd") -> Self: ...

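The two aliases added above describe the user-facing callables: a `PandasMapIterFunction` maps an iterator of pandas DataFrames to another iterator of pandas DataFrames, and an `ArrowMapIterFunction` does the same for `pyarrow.RecordBatch`es. A minimal sketch of functions satisfying each alias (the function names and the `age` column are illustrative only, not part of this change):

from collections.abc import Iterable

import pyarrow as pa
from pandas import DataFrame as PandasDataFrame


def keep_adults(frames: Iterable[PandasDataFrame]) -> Iterable[PandasDataFrame]:
    # Conforms to PandasMapIterFunction: consumes and produces pandas DataFrames.
    for pdf in frames:
        yield pdf[pdf.age >= 18]


def keep_adults_arrow(batches: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
    # Conforms to ArrowMapIterFunction: consumes and produces Arrow record batches.
    for batch in batches:
        pdf = batch.to_pandas()
        yield pa.RecordBatch.from_pandas(pdf[pdf.age >= 18])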
199 changes: 199 additions & 0 deletions duckdb/experimental/spark/sql/dataframe.py
@@ -1,4 +1,5 @@
import uuid # noqa: D100
from collections.abc import Iterable
from functools import reduce
from keyword import iskeyword
from typing import (
@@ -13,6 +14,7 @@

import duckdb
from duckdb import ColumnExpression, Expression, StarExpression
from duckdb.experimental.spark.exception import ContributionsAcceptedError

from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
from .column import Column
@@ -24,6 +26,7 @@
import pyarrow as pa
from pandas.core.frame import DataFrame as PandasDataFrame

from .._typing import ArrowMapIterFunction, PandasMapIterFunction
from ._typing import ColumnOrName
from .group import GroupedData
from .session import SparkSession
@@ -1419,5 +1422,201 @@ def cache(self) -> "DataFrame":
cached_relation = self.relation.execute()
return DataFrame(cached_relation, self.session)

def mapInArrow(
self,
func: "ArrowMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[Any] = None, # noqa: ANN401
) -> "DataFrame":
"""Maps an iterator of batches in the current :class:`DataFrame` using a Python native
function that is performed on `pyarrow.RecordBatch`\\s both as input and output,
and returns the result as a :class:`DataFrame`.

This method applies the specified Python function to an iterator of
`pyarrow.RecordBatch`\\s, each representing a batch of rows from the original DataFrame.
The returned iterator of `pyarrow.RecordBatch`\\s are combined as a :class:`DataFrame`.
The size of the function's input and output can be different. Each `pyarrow.RecordBatch`
size can be controlled by `spark.sql.execution.arrow.maxRecordsPerBatch`.

.. versionadded:: 3.3.0

Parameters
----------
func : function
a Python native function that takes an iterator of `pyarrow.RecordBatch`\\s, and
outputs an iterator of `pyarrow.RecordBatch`\\s.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
barrier : bool, optional, default False
Use barrier mode execution, ensuring that all Python workers in the stage will be
launched concurrently.

.. versionadded:: 3.5.0

profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
to be used for mapInArrow.

.. versionadded:: 4.0.0

Examples:
--------
>>> import pyarrow as pa
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
>>> def filter_func(iterator):
... for batch in iterator:
... pdf = batch.to_pandas()
... yield pa.RecordBatch.from_pandas(pdf[pdf.id == 1])
>>> df.mapInArrow(filter_func, df.schema).show()
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+

Set ``barrier`` to ``True`` to force the ``mapInArrow`` stage to run in
barrier mode; this ensures that all Python workers in the stage are
launched concurrently.

>>> df.mapInArrow(filter_func, df.schema, barrier=True).collect()
[Row(id=1, age=21)]

See Also:
--------
pyspark.sql.functions.pandas_udf
DataFrame.mapInPandas
""" # noqa: D205, D301
if isinstance(schema, str):
msg = "DDL-formatted type string is not supported yet for the 'schema' parameter."
raise ContributionsAcceptedError(msg)

if profile is not None:
msg = "ResourceProfile is not supported yet for the 'profile' parameter."
raise ContributionsAcceptedError(msg)

del barrier  # Ignored because DuckDB runs on a single node and has no barrier execution mode.

import pyarrow as pa
from pyarrow.dataset import dataset

arrow_schema = self.session.createDataFrame([], schema=schema).toArrow().schema
record_batches = self.relation.fetch_record_batch()
batch_generator = func(record_batches)
reader = pa.RecordBatchReader.from_batches(arrow_schema, batch_generator)
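# The reader is wrapped in a pyarrow dataset bound to a local variable so that
# DuckDB's replacement scan can resolve the name 'ds' in the SQL query below
# and scan the batches produced by the user function.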
ds = dataset(reader) # noqa: F841
df = DataFrame(self.session.conn.sql("SELECT * FROM ds"), self.session)
return df

def mapInPandas(
self,
func: "PandasMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[Any] = None, # noqa: ANN401
) -> "DataFrame":
"""Maps an iterator of batches in the current :class:`DataFrame` using a Python native
function that is performed on pandas DataFrames both as input and output,
and returns the result as a :class:`DataFrame`.

This method applies the specified Python function to an iterator of
`pandas.DataFrame`\\s, each representing a batch of rows from the original DataFrame.
The returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
The size of the function's input and output can be different. Each `pandas.DataFrame`
size can be controlled by `spark.sql.execution.arrow.maxRecordsPerBatch`.

.. versionadded:: 3.0.0

.. versionchanged:: 3.4.0
Supports Spark Connect.

Parameters
----------
func : function
a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
outputs an iterator of `pandas.DataFrame`\\s.
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
barrier : bool, optional, default False
Use barrier mode execution, ensuring that all Python workers in the stage will be
launched concurrently.

.. versionadded:: 3.5.0

profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
to be used for mapInPandas.

.. versionadded:: 4.0.0


Examples:
--------
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

Filter rows with id equal to 1:

>>> def filter_func(iterator):
... for pdf in iterator:
... yield pdf[pdf.id == 1]
>>> df.mapInPandas(filter_func, df.schema).show()
+---+---+
| id|age|
+---+---+
| 1| 21|
+---+---+

Compute the mean age for each id:

>>> def mean_age(iterator):
... for pdf in iterator:
... yield pdf.groupby("id").mean().reset_index()
>>> df.mapInPandas(mean_age, "id: bigint, age: double").show()
+---+----+
| id| age|
+---+----+
| 1|21.0|
| 2|30.0|
+---+----+

Add a new column with the double of the age:

>>> def double_age(iterator):
... for pdf in iterator:
... pdf["double_age"] = pdf["age"] * 2
... yield pdf
>>> df.mapInPandas(double_age, "id: bigint, age: bigint, double_age: bigint").show()
+---+---+----------+
| id|age|double_age|
+---+---+----------+
| 1| 21| 42|
| 2| 30| 60|
+---+---+----------+

Set ``barrier`` to ``True`` to force the ``mapInPandas`` stage to run in
barrier mode; this ensures that all Python workers in the stage are
launched concurrently.

>>> df.mapInPandas(filter_func, df.schema, barrier=True).collect()
[Row(id=1, age=21)]

See Also:
--------
pyspark.sql.functions.pandas_udf
DataFrame.mapInArrow
""" # noqa: D205, D301
import pyarrow as pa

def _build_arrow_func(pandas_func: "PandasMapIterFunction") -> "ArrowMapIterFunction":
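# Each incoming Arrow record batch is converted to pandas, passed through the
# user's pandas function, and the resulting DataFrames are converted back to
# record batches, so the pandas variant can delegate to mapInArrow.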
def _map_func(record_batches: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
pandas_iterator = (batch.to_pandas() for batch in record_batches)
pandas_result_gen = pandas_func(pandas_iterator)
batch_iterator = (pa.RecordBatch.from_pandas(pdf) for pdf in pandas_result_gen)
yield from batch_iterator

return _map_func

return self.mapInArrow(_build_arrow_func(func), schema, barrier, profile)


__all__ = ["DataFrame"]
88 changes: 88 additions & 0 deletions tests/fast/spark/test_spark_dataframe_map_in.py
@@ -0,0 +1,88 @@
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

_ = pytest.importorskip("duckdb.experimental.spark")

from spark_namespace.sql import functions as F
from spark_namespace.sql.types import Row


class TestDataFrameMapInMethods:
data = ((56, "Carol"), (20, "Alice"), (3, "Dave"), (3, "Anna"), (1, "Ben"))

def test_map_in_pandas(self, spark):
def filter_func(iterator):
for pdf in iterator:
yield pdf[pdf.age == 3]

df = spark.createDataFrame(self.data, ["age", "name"])
df = df.mapInPandas(filter_func, schema=df.schema)
df = df.sort(["age", "name"])

expected = [
Row(age=3, name="Anna"),
Row(age=3, name="Dave"),
]

assert df.collect() == expected

def test_map_in_pandas_empty_result(self, spark):
def filter_func(iterator):
for pdf in iterator:
yield pdf[pdf.age > 100]

df = spark.createDataFrame(self.data, ["age", "name"])
df = df.mapInPandas(filter_func, schema=df.schema)

expected = []

assert df.collect() == expected
assert df.schema == spark.createDataFrame([], schema=df.schema).schema

def test_map_in_pandas_large_dataset_ensure_no_data_loss(self, spark):
def identity_func(iterator):
for pdf in iterator:
pdf = pdf[pdf.id >= 0] # Apply a filter to ensure the DataFrame is evaluated
yield pdf

n = 10_000_000

pandas_df = pd.DataFrame(
{
"id": np.arange(n, dtype=np.int64),
"value_float": np.random.rand(n).astype(np.float32),
"value_int": np.random.randint(0, 1000, size=n, dtype=np.int32),
"category": np.random.randint(0, 10, size=n, dtype=np.int8),
}
)

df = spark.createDataFrame(pandas_df)
df = df.mapInPandas(identity_func, schema=df.schema)
# Apply filters to force evaluation of the entire DataFrame
df = df.filter(F.col("id") <= n).filter(F.col("id") >= 0).filter(F.col("category") >= 0)

generated_pandas_df = df.toPandas()
total_records = df.count()

assert total_records == n
assert pandas_df["id"].equals(generated_pandas_df["id"])

def test_map_in_arrow(self, spark):
def filter_func(iterator):
for batch in iterator:
df = batch.to_pandas()
df = df[df.age == 3]
yield pa.RecordBatch.from_pandas(df)

df = spark.createDataFrame(self.data, ["age", "name"])
df = df.mapInArrow(filter_func, schema=df.schema)
df = df.sort(["age", "name"])

expected = [
Row(age=3, name="Anna"),
Row(age=3, name="Dave"),
]

assert df.collect() == expected