Commit b46c010

[GH-2419] Register Sedona functions as built-in functions to support permanent VIEW creation (#2420)
1 parent 6cc4f6f commit b46c010

2 files changed (+99, -3 lines)

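Background on the failure this commit addresses: Spark does not allow a permanent view to reference a temporary function, and Sedona previously registered its SQL functions only in the session's function registry, so a CREATE VIEW statement using an ST_* function was rejected. A minimal sketch of the pre-fix behavior, assuming a standard SedonaContext setup (the view name and function choice below are illustrative, and the exact error text varies by Spark version):

import org.apache.spark.sql.SparkSession
import org.apache.sedona.spark.SedonaContext

val sedona: SparkSession =
  SedonaContext.create(SedonaContext.builder().master("local[*]").getOrCreate())

// Ad-hoc queries resolve session-registered (temporary) functions just fine.
sedona.sql("SELECT ST_Point(1.0, 2.0)").show()

// Before this commit, creating a *permanent* view failed with something like:
//   org.apache.spark.sql.AnalysisException: Not allowed to create a permanent
//   view ... by referencing a temporary function `st_point`
sedona.sql("CREATE VIEW point_view AS SELECT ST_Point(1.0, 2.0) AS geom")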

spark/common/src/main/scala/org/apache/sedona/sql/UDF/AbstractCatalog.scala

Lines changed: 14 additions & 2 deletions
@@ -20,6 +20,7 @@ package org.apache.sedona.sql.UDF
 import org.apache.spark.sql.{SQLContext, SparkSession, functions}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
 import org.apache.spark.sql.expressions.Aggregator
@@ -86,9 +87,20 @@ abstract class AbstractCatalog {
         functionIdentifier,
         expressionInfo,
         functionBuilder)
+      FunctionRegistry.builtin.registerFunction(
+        functionIdentifier,
+        expressionInfo,
+        functionBuilder)
+    }
+    aggregateExpressions.foreach { f =>
+      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))
+      FunctionRegistry.builtin.registerFunction(
+        FunctionIdentifier(f.getClass.getSimpleName),
+        new ExpressionInfo(f.getClass.getCanonicalName, null, f.getClass.getSimpleName),
+        (_: Seq[Expression]) =>
+          throw new UnsupportedOperationException(
+            s"Aggregate function ${f.getClass.getSimpleName} cannot be used as a regular function"))
     }
-    aggregateExpressions.foreach(f =>
-      sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f)))
   }
 
   def dropAll(sparkSession: SparkSession): Unit = {
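Why the double registration above works: Spark's SessionCatalog.isTemporaryFunction treats a name as temporary when it exists in the session registry but not in FunctionRegistry.builtin, and permanent-view creation rejects temporary functions, so registering the same builders in the built-in registry makes that check pass. Aggregates are still exposed as UDAFs via sparkSession.udf.register; their built-in entry exists only to make the name resolvable, which is why its builder deliberately throws. A hedged sanity check, assuming registry lookups normalize names to lower case (the st_intersects identifier is an example, not part of this commit):

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

// After Sedona's function registration runs, a Sedona function should be
// visible in the built-in registry, not just in the session's registry.
assert(FunctionRegistry.builtin.functionExists(FunctionIdentifier("st_intersects")))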

spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala

Lines changed: 85 additions & 1 deletion
@@ -19,6 +19,7 @@
 package org.apache.sedona.sql
 
 import org.apache.commons.codec.binary.Hex
+import org.apache.commons.io.FileUtils
 import org.apache.sedona.common.FunctionsGeoTools
 import org.apache.sedona.sql.implicits._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
@@ -35,7 +36,8 @@ import org.geotools.api.referencing.FactoryException
 import org.scalatest.{GivenWhenThen, Matchers}
 import org.xml.sax.InputSource
 
-import java.io.StringReader
+import java.io.{File, StringReader}
+import java.nio.file.Files
 import javax.xml.parsers.DocumentBuilderFactory
 import javax.xml.xpath.XPathFactory
 
@@ -4130,4 +4132,86 @@ class functionTestScala
       squareWithTwoHolesSimplified.first().getAs[org.locationtech.jts.geom.Geometry](0)
     assert(simplifiedMedialAxis != null, "Simplified medial axis should not be null")
   }
+
it("Test that CREATE VIEW fails with multiple temporary Sedona functions") {
+    val timestamp = System.currentTimeMillis()
+    val tmpDir: String =
+      Files.createTempDirectory("sedona_geoparquet_test_").toFile.getAbsolutePath
+
+    val buildings = sparkSession.sql("""
+        SELECT
+          ST_GeomFromWKT('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))') as geom,
+          'Building 1' as PROP_ADDR
+        UNION ALL
+        SELECT
+          ST_GeomFromWKT('POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))') as geom,
+          'Building 2' as PROP_ADDR
+      """)
+    buildings.write
+      .mode("overwrite")
+      .option("path", s"$tmpDir/sedona_test_${timestamp}_buildings")
+      .saveAsTable("nyc_buildings_geom_test")
+
+    val zones = sparkSession.sql("""
+        SELECT
+          ST_GeomFromWKT('POLYGON((0 0, 5 0, 5 5, 0 5, 0 0))') as zone_geom,
+          100.0 as elevation
+      """)
+    zones.write
+      .mode("overwrite")
+      .option("path", s"$tmpDir/sedona_test_${timestamp}_zones")
+      .saveAsTable("elevation_zones_test")
+
+    // Attempt to create a permanent VIEW with multiple Sedona functions
+    sparkSession.sql("""
+        CREATE VIEW nyc_buildings_with_functions AS
+        SELECT * FROM (
+          SELECT
+            nyc_buildings_geom_test.PROP_ADDR AS name,
+            nyc_buildings_geom_test.geom AS building_geom,
+            avg(elevation_zones_test.elevation) AS elevation
+          FROM
+            nyc_buildings_geom_test
+          JOIN
+            elevation_zones_test
+          ON
+            st_intersects(nyc_buildings_geom_test.geom, elevation_zones_test.zone_geom)
+          GROUP BY
+            nyc_buildings_geom_test.PROP_ADDR, nyc_buildings_geom_test.geom
+        )
+        WHERE elevation > 0
+      """)
+
+    // Query the view and assert results
+    val result = sparkSession.sql("SELECT * FROM nyc_buildings_with_functions").collect()
+    assert(result.length == 2, s"Expected 2 rows, but got ${result.length}")
+
+    // Assert both buildings are in the result
+    val buildingNames = result.map(_.getString(0)).toSet
+    assert(buildingNames.contains("Building 1"), "Building 1 should be in the result")
+    assert(buildingNames.contains("Building 2"), "Building 2 should be in the result")
+
+    sparkSession.sql("""
+        CREATE VIEW nyc_buildings_envelope_aggr_functions AS
+        SELECT
+          ST_Envelope_Aggr(nyc_buildings_geom_test.geom) AS building_geom_envelope
+        FROM
+          nyc_buildings_geom_test
+      """)
+
+    // Query the aggregate view and assert results
+    val result_aggr =
+      sparkSession.sql("SELECT * FROM nyc_buildings_envelope_aggr_functions").collect()
+    assert(result_aggr.length == 1, s"Expected 1 row, but got ${result_aggr.length}")
+
+    // Assert that the views were created
+    val views = sparkSession.sql("SHOW VIEWS").collect()
+    val view1Exists = views.exists(row => row.getString(1) == "nyc_buildings_with_functions")
+    assert(view1Exists, "View 'nyc_buildings_with_functions' should be created")
+    val view2Exists =
+      views.exists(row => row.getString(1) == "nyc_buildings_envelope_aggr_functions")
+    assert(view2Exists, "View 'nyc_buildings_envelope_aggr_functions' should be created")
+
+    FileUtils.deleteDirectory(new File(tmpDir))
+  }
 }
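For reference, the end-user flow enabled by this change (table and view names below are hypothetical): once Sedona is initialized, permanent views that reference Sedona functions can be created and re-resolved later, because the functions now resolve through the built-in registry.

import org.apache.sedona.spark.SedonaContext

val sedona = SedonaContext.create(SedonaContext.builder().master("local[*]").getOrCreate())

// Hypothetical table and view; any Sedona SQL function behaves the same way.
sedona.sql("CREATE TABLE IF NOT EXISTS pts AS SELECT ST_Point(1.0, 2.0) AS geom")
sedona.sql("CREATE VIEW IF NOT EXISTS pts_wkt AS SELECT ST_AsText(geom) AS wkt FROM pts")
sedona.sql("SELECT wkt FROM pts_wkt").show()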
