Skip to content

Commit 2b36847

Browse files
Hari Kishore Chaparala authored and bsharifi committed
Bug fix: Attribute nullability for set operations
1 parent 46a2e9d commit 2b36847

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

src/it/scala/io/github/spark_redshift_community/spark/redshift/pushdown/PushdownLogicalPlanOperatorSuite.scala

Lines changed: 41 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1225,6 +1225,47 @@ abstract class PushdownLogicalPlanOperatorSuite extends IntegrationPushdownSuite
1225 1225
doTest(sqlContext, testUnion13)
1226 1226
}
1227 1227

1228+
test("Test UNION ALL attribute nullability") {
1229+
withTwoTempRedshiftTables("tableA", "tableB") { (tableA, tableB) =>
1230+
redshiftWrapper.executeUpdate(conn,
1231+
s"create table $tableA (id INT not null, category varchar(50))")
1232+
redshiftWrapper.executeUpdate(conn,
1233+
s"create table $tableB (id INT, category varchar(50))")
1234+
1235+
redshiftWrapper.executeUpdate(conn,
1236+
s"insert into $tableA VALUES (1, 'A'), (1, 'B'), (2, 'A')")
1237+
redshiftWrapper.executeUpdate(conn,
1238+
s"insert into $tableB VALUES (1, 'C'), (1, 'D'), (NULL, 'A')")
1239+
1240+
read.option("dbtable", tableA).load.createOrReplaceTempView(tableA)
1241+
read.option("dbtable", tableB).load.createOrReplaceTempView(tableB)
1242+
1243+
val strQuery =
1244+
s"select id, count(*) as cnt from " +
1245+
s"(select id, category from $tableA union all select id, category from $tableB) " +
1246+
s"group by rollup (id) order by id, cnt"
1247+
1248+
// If the nullability is wrong, the connector will return a row of (0, 1) instead of (null, 1)
1249+
// because Spark will misapply the non-nullability of the first table to the second table and
1250+
// convert the null column value into a zero [Redshift-87788]
1251+
checkAnswer(
1252+
sqlContext.sql(strQuery),
1253+
Seq(Row(null, 1),
1254+
Row(null, 6),
1255+
Row(1, 4),
1256+
Row(2, 1))
1257+
)
1258+
1259+
// We don't expect the group by rollup expression to be pushed down.
1260+
checkSqlStatement(
1261+
s"""( SELECT ( "SQ_0"."ID" ) AS "SQ_1_COL_0" FROM
1262+
| ( SELECT * FROM "PUBLIC"."$tableA" AS "RCQ_ALIAS" ) AS "SQ_0" ) UNION ALL
1263+
| ( SELECT ( "SQ_0"."ID" ) AS "SQ_1_COL_0" FROM
1264+
| ( SELECT * FROM "PUBLIC"."$tableB" AS "RCQ_ALIAS" ) AS "SQ_0" )""".stripMargin
1265+
)
1266+
}
1267+
}
1268+
1228 1269
// No push down for except
1229 1270
test("Test EXCEPT logical plan operator") {
1230 1271
// "Column name" and result set

src/main/scala/io/github/spark_redshift_community/spark/redshift/pushdown/querygeneration/RedshiftQuery.scala

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -789,17 +789,53 @@ case class SetQuery(children: Seq[LogicalPlan],
789 789
new QueryBuilder(child).treeRoot
790 790
}
791 791

792-
if(queries.contains(null)) {
792+
if (queries.contains(null)) {
793 793
throw new RedshiftPushdownUnsupportedException(
794 794
RedshiftFailMessage.FAIL_PUSHDOWN_STATEMENT,
795 795
setOperation,
796 796
"Not all " + setOperation + " children query supported",
797 797
false)
798 798
}
799+
800+
private val childrenOutputs: Seq[Seq[Attribute]] = queries.map(_.helper.output)
801+
802+
private val resolvedAttributes: Seq[Attribute] = setOperation match {
803+
case "UNION ALL" =>
804+
// For UNION ALL: resolved nullable = OR of all child nullabilities
805+
childrenOutputs.head.indices.map { i =>
806+
val childNullabilities = childrenOutputs.map(_.apply(i).nullable)
807+
val resolvedNullable = childNullabilities.contains(true)
808+
val baseAttr = childrenOutputs.head(i)
809+
baseAttr.withNullability(resolvedNullable)
810+
}
811+
812+
case "EXCEPT" =>
813+
// For EXCEPT: resolved nullability = from the first query
814+
childrenOutputs.head
815+
816+
case "INTERSECT" =>
817+
// For INTERSECT: resolved nullable = AND of all child nullabilities
818+
// If any child is non-nullable (false), resolved is non-nullable (false).
819+
childrenOutputs.head.indices.map { i =>
820+
val childNullabilities = childrenOutputs.map(_.apply(i).nullable)
821+
val resolvedNullable = childNullabilities.forall(_ == true)
822+
val baseAttr = childrenOutputs.head(i)
823+
baseAttr.withNullability(resolvedNullable)
824+
}
825+
826+
case other =>
827+
throw new RedshiftPushdownUnsupportedException(
828+
"Unsupported set query pushdown",
829+
other,
830+
"Only UNION ALL, INTERSECT and EXCEPT are supported",
831+
true
832+
)
833+
}
834+
799 835
override val helper: QueryHelper =
800 836
QueryHelper(
801 837
children = queries,
802-
outputAttributes = Some(queries.head.helper.output),
838+
outputAttributes = Some(resolvedAttributes),
803 839
alias = alias,
804 840
visibleAttributeOverride =
805 841
Some(queries.foldLeft(Seq.empty[Attribute])((x, y) => x ++ y.helper.output).map(

0 commit comments

Comments
 (0)