Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 155 additions & 7 deletions duckdb/experimental/spark/sql/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import uuid # noqa: D100
import itertools # noqa: D100
import uuid
from collections.abc import Iterable, Iterator
from functools import reduce
from keyword import iskeyword
from typing import (
Expand All @@ -20,6 +22,8 @@
from .type_utils import duckdb_to_spark_schema
from .types import Row, StructType

_LOCAL_ITERATOR_BATCH_SIZE = 10_000

if TYPE_CHECKING:
import pyarrow as pa
from pandas.core.frame import DataFrame as PandasDataFrame
Expand All @@ -31,6 +35,12 @@
from duckdb.experimental.spark.sql import functions as spark_sql_functions


def _construct_row(values: Iterable, names: list[str]) -> Row:
    """Build a :class:`Row` from *values* without invoking ``Row.__init__``.

    ``Row`` is a tuple subclass, so the instance is allocated directly via
    ``tuple.__new__`` and the column names are attached as ``__fields__``.
    """
    row = tuple.__new__(Row, tuple(values))
    row.__fields__ = [*names]
    return row


class DataFrame: # noqa: D101
def __init__(self, relation: duckdb.DuckDBPyRelation, session: "SparkSession") -> None: # noqa: D107
self.relation = relation
Expand Down Expand Up @@ -71,6 +81,149 @@ def toArrow(self) -> "pa.Table":
"""
return self.relation.to_arrow_table()

def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
    """Return an iterator over all rows of this :class:`DataFrame`.

    Rows are fetched lazily from the underlying relation in batches of
    ``_LOCAL_ITERATOR_BATCH_SIZE``, so at most one batch is held in memory
    at a time.

    .. versionadded:: 2.0.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    prefetchPartitions : bool, optional
        Accepted for PySpark API compatibility; it is not used by this
        implementation (rows are always fetched batch-by-batch).

        .. versionchanged:: 3.4.0
            This argument does not take effect for Spark Connect.

    Returns:
    -------
    Iterator
        Iterator of rows.

    Examples:
    --------
    >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
    >>> list(df.toLocalIterator())
    [Row(age=14, name='Tom'), Row(age=23, name='Alice'), Row(age=16, name='Bob')]
    """
    names = self.relation.columns
    cursor = self.relation.execute()

    try:
        while batch := cursor.fetchmany(_LOCAL_ITERATOR_BATCH_SIZE):
            for values in batch:
                yield _construct_row(values, names)
    finally:
        # Release the result set even if the caller abandons the generator.
        cursor.close()

def foreach(self, f: Callable[[Row], None]) -> None:
    """Apply the ``f`` function to every :class:`Row` of this :class:`DataFrame`.

    Rows are streamed through :meth:`toLocalIterator`, so the full result
    set is never materialized at once. This is a shorthand for
    ``df.rdd.foreach()``.

    .. versionadded:: 1.3.0

    .. versionchanged:: 4.0.0
        Supports Spark Connect.

    Parameters
    ----------
    f : function
        A function that accepts one parameter which will
        receive each row to process; its return value is ignored.

    Examples:
    --------
    >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
    >>> def func(person):
    ...     print(person.name)
    >>> df.foreach(func)
    """
    for record in self.toLocalIterator():
        f(record)

def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
    """Apply the ``f`` function to each partition of this :class:`DataFrame`.

    This a shorthand for ``df.rdd.foreachPartition()``. DuckDB does not
    expose Spark-style partitions, so rows are grouped into
    pseudo-partitions of ``_LOCAL_ITERATOR_BATCH_SIZE`` rows each and ``f``
    is called once per group with an iterator over it.

    .. versionadded:: 1.3.0

    .. versionchanged:: 4.0.0
        Supports Spark Connect.

    Parameters
    ----------
    f : function
        A function that accepts one parameter which will receive
        each partition to process.

    Examples:
    --------
    >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
    >>> def func(itr):
    ...     for person in itr:
    ...         print(person.name)
    >>> df.foreachPartition(func)
    """
    rows_generator = self.toLocalIterator()
    # BUG FIX: itertools.islice() returns an iterator object, which is
    # always truthy even once the source generator is exhausted, so
    # ``while rows := itertools.islice(...)`` would loop forever, calling
    # ``f`` with empty iterators. Materializing the batch with list()
    # makes the loop condition become falsy when no rows remain.
    while batch := list(itertools.islice(rows_generator, _LOCAL_ITERATOR_BATCH_SIZE)):
        f(iter(batch))

def isEmpty(self) -> bool:
    """Check whether this :class:`DataFrame` contains any rows.

    .. versionadded:: 3.3.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Returns:
    -------
    bool
        ``True`` if the DataFrame is empty, ``False`` otherwise.

    See Also:
    --------
    DataFrame.count : Counts the number of rows in DataFrame.

    Notes:
    -----
    - An empty DataFrame has no rows. It may have columns, but no data.

    Examples:
    --------
    Example 1: Checking if an empty DataFrame is empty

    >>> df_empty = spark.createDataFrame([], "a STRING")
    >>> df_empty.isEmpty()
    True

    Example 2: Checking if a non-empty DataFrame is empty

    >>> df_non_empty = spark.createDataFrame(["a"], "STRING")
    >>> df_non_empty.isEmpty()
    False

    Example 3: Checking if a DataFrame with null values is empty

    >>> df_nulls = spark.createDataFrame([(None, None)], "a STRING, b INT")
    >>> df_nulls.isEmpty()
    False

    Example 4: Checking if a DataFrame with no rows but with columns is empty

    >>> df_no_rows = spark.createDataFrame([], "id INT, value STRING")
    >>> df_no_rows.isEmpty()
    True
    """
    # A single-row probe is enough: no first row means no rows at all.
    first_row = self.first()
    return first_row is None

def createOrReplaceTempView(self, name: str) -> None:
"""Creates or replaces a local temporary view with this :class:`DataFrame`.

Expand Down Expand Up @@ -1381,12 +1534,7 @@ def collect(self) -> list[Row]: # noqa: D102
columns = self.relation.columns
result = self.relation.fetchall()

def construct_row(values: list, names: list[str]) -> Row:
row = tuple.__new__(Row, list(values))
row.__fields__ = list(names)
return row

rows = [construct_row(x, columns) for x in result]
rows = [_construct_row(x, columns) for x in result]
return rows

def cache(self) -> "DataFrame":
Expand Down
39 changes: 39 additions & 0 deletions tests/fast/spark/test_spark_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest import mock

import pytest

_ = pytest.importorskip("duckdb.experimental.spark")
Expand Down Expand Up @@ -597,3 +599,40 @@ def test_treeString_array_type(self, spark):
assert " |-- name:" in tree
assert " |-- hobbies: array<" in tree
assert "(nullable = true)" in tree

def test_method_is_empty(self, spark):
    # isEmpty must be False for a populated frame and True for a frame
    # that shares the schema but holds no rows.
    populated_df = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"])
    empty_df = spark.createDataFrame([], schema=populated_df.schema)

    assert not populated_df.isEmpty()
    assert empty_df.isEmpty()

def test_dataframe_foreach(self, spark):
    source_rows = [(56, "Carol"), (20, "Alice"), (3, "Dave")]
    df = spark.createDataFrame(source_rows, ["age", "name"])
    expected_rows = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")]

    sink = mock.MagicMock()
    df.foreach(sink)
    # Each row must be delivered exactly once; ordering is not guaranteed.
    sink.assert_has_calls([mock.call(row) for row in expected_rows], any_order=True)

def test_dataframe_foreach_partition(self, spark):
    data = [(56, "Carol"), (20, "Alice"), (3, "Dave")]
    df = spark.createDataFrame(data, ["age", "name"])
    expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")]

    mock_callable = mock.MagicMock()
    df.foreachPartition(mock_callable)
    # BUG FIX: foreachPartition hands ``f`` an *iterator* per partition, so
    # ``assert_called_once_with(expected)`` (iterator == list) can never
    # match. Materialize the single partition's argument before comparing.
    mock_callable.assert_called_once()
    (partition,) = mock_callable.call_args[0]
    assert list(partition) == expected

def test_to_local_iterator(self, spark):
    source = [(56, "Carol"), (20, "Alice"), (3, "Dave")]
    df = spark.createDataFrame(source, ["age", "name"])
    expected = [Row(age=56, name="Carol"), Row(age=20, name="Alice"), Row(age=3, name="Dave")]

    # Draining the iterator must yield every row, in insertion order.
    collected = [row for row in df.toLocalIterator()]
    assert collected == expected
Loading