-
Notifications
You must be signed in to change notification settings - Fork 29.1k
Description
What type of issue is this?
Bug
Spark version
4.0.1 (with Iceberg 1.10.1)
Describe the bug
When using Storage-Partitioned Join (SPJ) with spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled=true, both dropDuplicates() and Window-based dedup (row_number()) produce incorrect results — duplicate rows that should have been removed survive in the output.
Root cause
Partial clustering splits a partition with many files across multiple tasks to improve parallelism. However, downstream dedup operations (dropDuplicates, row_number() OVER (PARTITION BY ...)) rely on the assumption that all rows with the same partition key are co-located in a single task. Since SPJ eliminates the Exchange (shuffle), each split independently deduplicates, and duplicate partition keys survive across splits.
Steps to reproduce
- Create two Iceberg tables partitioned by `part_key`:
  - Big table: 20 separate appends to partition `p1` (= 20 data files), plus 1 append to `p2`
  - Small table: 1 append containing both `p1` and `p2`
- Enable SPJ with partial clustering:
spark.sql.sources.v2.bucketing.enabled = true
spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled = true
spark.sql.iceberg.planning.preserve-data-grouping = true
spark.sql.autoBroadcastJoinThreshold = -1
- Perform a `leftsemi` join on `part_key`, then `dropDuplicates(["part_key"])`
- The physical plan contains no Exchange node, confirming SPJ is active
import shutil
import tempfile
from pyspark.sql import SparkSession, Window, functions as F
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
import pytest
# Maven coordinates of the Iceberg Spark runtime for Spark 4.0 / Scala 2.13,
# pinned to the Iceberg version the bug was observed with (1.10.1).
ICEBERG_PKG = "org.apache.iceberg:iceberg-spark-runtime-4.0_2.13:1.10.1"
# Shared row schema for both test tables; `part_key` is the partition column.
SCHEMA = StructType([
StructField("id", StringType()),
StructField("part_key", StringType()),
StructField("value", IntegerType()),
StructField("padding", StringType()),
])
@pytest.fixture(scope="module")
def warehouse_dir():
    """Module-scoped temporary directory used as the Iceberg warehouse root.

    Yields the directory path; removes the tree (best-effort) on teardown.
    """
    tmp_path = tempfile.mkdtemp(prefix="iceberg_spj_test_")
    yield tmp_path
    shutil.rmtree(tmp_path, ignore_errors=True)
@pytest.fixture(scope="module")
def spark_iceberg(warehouse_dir):
    """Module-scoped local SparkSession wired to an Iceberg hadoop catalog.

    AQE is disabled and broadcast joins are ruled out elsewhere so the
    storage-partitioned-join plan shape is deterministic.
    """
    session_conf = {
        "spark.jars.packages": ICEBERG_PKG,
        "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
        "spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
        "spark.sql.catalog.local.type": "hadoop",
        "spark.sql.catalog.local.warehouse": warehouse_dir,
        "spark.ui.enabled": "false",
        "spark.driver.memory": "2g",
        "spark.sql.shuffle.partitions": "4",
        "spark.sql.adaptive.enabled": "false",
    }
    builder = SparkSession.builder.master("local[4]").appName("spj_window_dedup_bug")
    for key, value in session_conf.items():
        builder = builder.config(key, value)
    session = builder.getOrCreate()
    session.sparkContext.setLogLevel("WARN")
    yield session
    session.stop()
@pytest.fixture()
def asymmetric_tables(spark_iceberg):
    """Two Iceberg tables with asymmetric file counts to trigger partial clustering."""
    big, small = "local.db.big", "local.db.small"
    for table_name in (big, small):
        spark_iceberg.sql(f"DROP TABLE IF EXISTS {table_name}")
        spark_iceberg.sql(f"""
            CREATE TABLE {table_name} (id STRING, part_key STRING, value INT, padding STRING)
            USING iceberg PARTITIONED BY (part_key)
        """)
    # Big table: 20 separate appends => 20 data files for partition p1.
    for batch in range(20):
        rows = [(f"big_{batch}_{row}", "p1", batch, "X" * 1000) for row in range(200)]
        spark_iceberg.createDataFrame(rows, SCHEMA).writeTo(big).append()
    # One extra row in a second partition so the join has two keys.
    spark_iceberg.createDataFrame([("big_other", "p2", 99, "Y")], SCHEMA).writeTo(big).append()
    # Small table: a single append => 1 data file covering both partitions.
    spark_iceberg.createDataFrame(
        [("small_0", "p1", 0, "Z"), ("small_1", "p2", 1, "Z")], SCHEMA
    ).writeTo(small).append()
    # Enable SPJ with partially clustered distribution for every test using
    # this fixture (function scope re-applies it even after a test flips it).
    spj_conf = {
        "spark.sql.sources.v2.bucketing.enabled": "true",
        "spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled": "true",
        "spark.sql.iceberg.planning.preserve-data-grouping": "true",
        "spark.sql.autoBroadcastJoinThreshold": "-1",
    }
    for key, value in spj_conf.items():
        spark_iceberg.conf.set(key, value)
    return spark_iceberg.table(big), spark_iceberg.table(small)
def _get_plan(df) -> str:
    """Return the executed physical plan of *df* as a string (via the JVM DataFrame)."""
    execution = df._jdf.queryExecution()
    return execution.executedPlan().toString()
class TestSPJDedupBug:
    """Reproduces incorrect dedup results under SPJ with partial clustering.

    The first two tests assert the *buggy* behavior (duplicate partition keys
    survive); the third shows the workaround of disabling partial clustering.
    """

    def test_drop_duplicates_after_join_produces_duplicates(self, asymmetric_tables):
        """dropDuplicates after SPJ join → duplicates survive (expected 2, gets >2)."""
        big, small = asymmetric_tables
        semi_joined = big.join(small, on="part_key", how="leftsemi")
        deduped = semi_joined.dropDuplicates(["part_key"])
        # No Exchange in the plan confirms the shuffle was elided (SPJ active).
        assert "Exchange" not in _get_plan(deduped)
        assert deduped.count() > 2

    def test_window_dedup_after_join_produces_duplicates(self, asymmetric_tables):
        """row_number() Window dedup after SPJ join → duplicates survive."""
        big, small = asymmetric_tables
        semi_joined = big.join(small, on="part_key", how="leftsemi")
        win = Window.partitionBy("part_key").orderBy(F.col("value").desc())
        ranked = semi_joined.withColumn("_r", F.row_number().over(win))
        deduped = ranked.filter("_r = 1").drop("_r")
        assert "Exchange" not in _get_plan(deduped)
        assert deduped.count() > 2

    def test_disabling_partial_clustering_fixes_it(self, asymmetric_tables, spark_iceberg):
        """Setting partiallyClusteredDistribution.enabled=false → dedup works."""
        big, small = asymmetric_tables
        spark_iceberg.conf.set(
            "spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "false"
        )
        deduped = big.join(small, on="part_key", how="leftsemi").dropDuplicates(["part_key"])
        assert deduped.count() == 2
Related
SPARK-38166: Duplicates after task failure in dropDuplicates and repartition