Merged
34 changes: 34 additions & 0 deletions docs/api/sql/NearestNeighbourSearching.md
@@ -68,6 +68,40 @@ CACHE TABLE knnResult;
SELECT * FROM knnResult WHERE condition;
```

### Optimization Barrier

Use the `barrier` function to prevent filter pushdown and control predicate evaluation order in complex spatial joins. This function creates an optimization barrier by evaluating boolean expressions at runtime.

The `barrier` function takes a boolean expression as a string, followed by pairs of variable names and their values that will be substituted into the expression:

```sql
barrier(expression, var_name1, var_value1, var_name2, var_value2, ...)
```
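
The pairing of names and values can be sketched in plain Python. This is only an illustration of the argument layout, not Sedona's implementation; `build_variable_map` is a hypothetical helper name, and the error messages are assumptions modeled on the behavior described above.

```python
from typing import Any, Dict


def build_variable_map(*args: Any) -> Dict[str, Any]:
    """Pair up (name, value) arguments in the layout barrier() expects.

    Hypothetical helper for illustration only; it is not part of Sedona.
    """
    # Arguments after the expression must come in (name, value) pairs.
    if len(args) % 2 != 0:
        raise ValueError("barrier requires pairs of variable names and values")

    variables: Dict[str, Any] = {}
    for i in range(0, len(args), 2):
        name, value = args[i], args[i + 1]
        # Each name must be a string so it can be looked up in the expression.
        if not isinstance(name, str):
            raise TypeError(
                f"variable name must be a string, got: {type(name).__name__}"
            )
        variables[name] = value
    return variables


# Values for one joined (hotel, restaurant) row.
row_variables = build_variable_map("rating", 4.5, "stars", 5)
```

Each name in the resulting map can then be referenced inside the expression string, for example `rating > 4.0 AND stars >= 4`.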

The placement of filters relative to KNN joins changes the semantic meaning of the query:

- **Filter before KNN**: First filters the data, then finds K nearest neighbors from the filtered subset. This answers "What are the K nearest high-rated restaurants?"
- **Filter after KNN**: First finds K nearest neighbors from all data, then filters those results. This answers "Of the K nearest restaurants, which ones are high-rated?"
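
The difference between the two orderings can be seen on a toy dataset. The following is a minimal plain-Python sketch using a brute-force nearest-neighbor search; the coordinates, ratings, and the `k_nearest` helper are made up for illustration and do not involve Sedona.

```python
import math
from typing import List, Tuple

# (name, x, y, rating); made-up sample data for illustration.
Restaurant = Tuple[str, float, float, float]


def k_nearest(
    origin: Tuple[float, float], candidates: List[Restaurant], k: int
) -> List[Restaurant]:
    """Brute-force K nearest neighbors by Euclidean distance."""
    return sorted(
        candidates,
        key=lambda r: math.hypot(r[1] - origin[0], r[2] - origin[1]),
    )[:k]


hotel = (0.0, 0.0)
restaurants: List[Restaurant] = [
    ("A", 1.0, 0.0, 3.5),
    ("B", 2.0, 0.0, 4.8),
    ("C", 3.0, 0.0, 3.9),
    ("D", 4.0, 0.0, 4.6),
    ("E", 5.0, 0.0, 4.2),
]
K = 3

# Filter before KNN: the 3 nearest among high-rated restaurants only.
high_rated = [r for r in restaurants if r[3] > 4.0]
filter_then_knn = [r[0] for r in k_nearest(hotel, high_rated, K)]

# Filter after KNN: the 3 nearest overall, then keep the high-rated ones.
knn_then_filter = [r[0] for r in k_nearest(hotel, restaurants, K) if r[3] > 4.0]

print(filter_then_knn)  # ['B', 'D', 'E']
print(knn_then_filter)  # ['B']
```

The same predicate produces three matches in one ordering and one match in the other, which is why the position of the filter relative to the KNN join matters.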

### Example

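For each hotel, find its 3 nearest restaurants, then keep only the pairs where the restaurant is high-rated and the hotel is a luxury hotel. The barrier ensures the KNN join completes before the filter is applied.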

```sql
SELECT
    h.name AS hotel,
    r.name AS restaurant,
    r.rating
FROM hotels AS h
INNER JOIN restaurants AS r
    ON ST_KNN(h.geometry, r.geometry, 3, false)
WHERE barrier('rating > 4.0 AND stars >= 4',
              'rating', r.rating,
              'stars', h.stars)
```

With the barrier function, this query first finds the 3 nearest restaurants to each hotel (regardless of rating), then filters to keep only those pairs where the restaurant has rating > 4.0 and the hotel has stars >= 4. Without the barrier, an optimizer might push the filters down, changing the query to first filter for high-rated restaurants and luxury hotels, then find the 3 nearest among those filtered sets.
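
The per-row check that the barrier applies to each joined pair can be approximated in plain Python. The sketch below is a simplified stand-in, not Sedona's parser: it handles numeric literals, variables, comparison operators, `AND`, `OR`, `NOT`, and parentheses, and it omits string, boolean, and null literals. The names `tokenize` and `evaluate` are hypothetical.

```python
import re
from typing import Any, Dict, List, Optional

# Operators, parentheses, numbers, and identifiers (simplified grammar).
TOKEN = re.compile(r"\s*(>=|<=|!=|<>|=|>|<|\(|\)|\d+(?:\.\d+)?|[A-Za-z_]\w*)")
COMPARE = {
    "=": lambda a, b: a == b,
    "!=": lambda a, b: a != b,
    "<>": lambda a, b: a != b,
    "<": lambda a, b: a < b,
    "<=": lambda a, b: a <= b,
    ">": lambda a, b: a > b,
    ">=": lambda a, b: a >= b,
}


def tokenize(expression: str) -> List[str]:
    """Split the expression string into tokens."""
    expression = expression.rstrip()
    tokens, pos = [], 0
    while pos < len(expression):
        match = TOKEN.match(expression, pos)
        if match is None:
            raise ValueError(f"unexpected input at position {pos}")
        tokens.append(match.group(1))
        pos = match.end()
    return tokens


def evaluate(expression: str, variables: Dict[str, Any]) -> bool:
    """Evaluate a boolean expression against one row's variable values."""
    tokens = tokenize(expression)
    pos = 0

    def peek() -> Optional[str]:
        # Keywords are matched case-insensitively.
        return tokens[pos].upper() if pos < len(tokens) else None

    def take() -> str:
        nonlocal pos
        if pos >= len(tokens):
            raise ValueError("unexpected end of expression")
        pos += 1
        return tokens[pos - 1]

    def value() -> Any:
        token = take()
        if token[0].isdigit():
            return float(token)
        if token not in variables:
            raise ValueError(f"unknown variable: {token}")
        return variables[token]

    def primary() -> bool:
        if peek() == "(":
            take()
            result = or_expr()
            if take() != ")":
                raise ValueError("expected ')'")
            return result
        left = value()
        if peek() in COMPARE:
            return COMPARE[take()](left, value())
        return bool(left)  # bare variable used as a boolean

    def not_expr() -> bool:
        if peek() == "NOT":
            take()
            return not not_expr()
        return primary()

    def and_expr() -> bool:
        result = not_expr()
        while peek() == "AND":
            take()
            right = not_expr()  # always parse the right operand
            result = result and right
        return result

    def or_expr() -> bool:
        result = and_expr()
        while peek() == "OR":
            take()
            right = and_expr()
            result = result or right
        return result

    result = or_expr()
    if pos != len(tokens):
        raise ValueError(f"unexpected token: {tokens[pos]}")
    return result


keep_pair = evaluate("rating > 4.0 AND stars >= 4", {"rating": 4.5, "stars": 5})
drop_pair = evaluate("rating > 4.0 AND stars >= 4", {"rating": 3.9, "stars": 5})
```

Here `keep_pair` corresponds to a (hotel, restaurant) pair that survives the filter and `drop_pair` to one that is discarded after the KNN join has already selected it.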

### Handling SQL-Defined Tables in ST_KNN Joins:

When creating DataFrames from hard-coded SQL select statements in Sedona, and later using them in `ST_KNN` joins, Sedona may attempt to optimize the query in a way that bypasses the intended kNN join logic. Specifically, if you create DataFrames with hard-coded SQL, such as:
19 changes: 19 additions & 0 deletions python/sedona/spark/sql/st_functions.py
@@ -295,6 +295,25 @@ def ST_Azimuth(point_a: ColumnOrName, point_b: ColumnOrName) -> Column:
    return _call_st_function("ST_Azimuth", (point_a, point_b))


@validate_argument_types
def barrier(expression: ColumnOrName, *args) -> Column:
    """Prevent filter pushdown and control predicate evaluation order in complex spatial joins.
    This function creates an optimization barrier by evaluating boolean expressions at runtime.

    :param expression: Boolean expression string to evaluate
    :type expression: ColumnOrName
    :param args: Variable name and value pairs (var_name1, var_value1, var_name2, var_value2, ...)
    :return: Boolean result of the expression evaluation
    :rtype: Column

    Example:
        df.where(barrier('rating > 4.0 AND stars >= 4',
                         'rating', col('r.rating'),
                         'stars', col('h.stars')))
    """
    return _call_st_function("barrier", (expression,) + args)


@validate_argument_types
def ST_BestSRID(geometry: ColumnOrName) -> Column:
    """Estimates the best SRID (EPSG code) of the geometry.
@@ -249,6 +249,7 @@ object Catalog extends AbstractCatalog {
function[ST_Rotate](),
function[ST_RotateX](),
function[ST_RotateY](),
function[Barrier](),
// Expression for rasters
function[RS_NormalizedDifference](),
function[RS_Mean](),
@@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.expressions

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, BooleanType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import scala.util.parsing.combinator._

/**
* Barrier function to prevent filter pushdown and control predicate evaluation order. Takes a
* boolean expression string followed by pairs of variable names and their values.
*
* Usage: barrier(expression, var_name1, var_value1, var_name2, var_value2, ...) Example:
* barrier('rating > 4.0 AND stars >= 4', 'rating', r.rating, 'stars', h.stars)
*
* Extends CodegenFallback to prevent Catalyst optimizer from pushing this filter through joins.
* CodegenFallback makes this expression opaque to optimization rules, ensuring it evaluates at
* runtime in its original position within the query plan.
*/
private[apache] case class Barrier(inputExpressions: Seq[Expression])
    extends Expression
    with CodegenFallback {

  override def nullable: Boolean = false

  override def dataType: DataType = BooleanType

  override def children: Seq[Expression] = inputExpressions

  override def eval(input: InternalRow): Any = {
    // Get the expression string
    val exprString = inputExpressions.head.eval(input) match {
      case s: UTF8String => s.toString
      case null => throw new IllegalArgumentException("Barrier expression cannot be null")
      case other =>
        throw new IllegalArgumentException(
          s"Barrier expression must be a string, got: ${other.getClass}")
    }

    // Build variable map from pairs
    val varMap = scala.collection.mutable.Map[String, Any]()
    var i = 1
    while (i < inputExpressions.length) {
      if (i + 1 >= inputExpressions.length) {
        throw new IllegalArgumentException(
          "Barrier function requires pairs of variable names and values")
      }

      val varName = inputExpressions(i).eval(input) match {
        case s: UTF8String => s.toString
        case null => throw new IllegalArgumentException("Variable name cannot be null")
        case other =>
          throw new IllegalArgumentException(
            s"Variable name must be a string, got: ${other.getClass}")
      }

      val varValue = inputExpressions(i + 1).eval(input)
      varMap(varName) = varValue
      i += 2
    }

    // Evaluate the expression with variable substitution
    evaluateBooleanExpression(exprString, varMap.toMap)
  }

  /**
   * Evaluates a boolean expression string with variable substitution. Supports basic comparison
   * operators and logical operators (AND, OR, NOT).
   */
  private def evaluateBooleanExpression(
      expression: String,
      variables: Map[String, Any]): Boolean = {
    val parser = new BooleanExpressionParser(variables)
    parser.parseExpression(expression) match {
      case parser.Success(result, _) => result
      case parser.NoSuccess(msg, _) =>
        throw new IllegalArgumentException(s"Failed to parse barrier expression: $msg")
    }
  }

  protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
    copy(inputExpressions = newChildren)
  }
}

/**
* Parser for boolean expressions in barrier function. Supports comparison operators: =, !=, <>,
* <, <=, >, >= Supports logical operators: AND, OR, NOT Supports parentheses for grouping
*/
private class BooleanExpressionParser(variables: Map[String, Any]) extends JavaTokenParsers {

  // Pre-compiled regex patterns for better performance
  private val truePattern = "(?i)true".r
  private val falsePattern = "(?i)false".r
  private val nullPattern = "(?i)null".r
  private val andPattern = "(?i)AND".r
  private val orPattern = "(?i)OR".r
  private val notPattern = "(?i)NOT".r

  def parseExpression(expr: String): ParseResult[Boolean] = parseAll(boolExpr, expr)

  def boolExpr: Parser[Boolean] = orExpr

  def orExpr: Parser[Boolean] = andExpr ~ rep(orPattern ~> andExpr) ^^ { case left ~ rights =>
    rights.foldLeft(left)(_ || _)
  }

  def andExpr: Parser[Boolean] = notExpr ~ rep(andPattern ~> notExpr) ^^ { case left ~ rights =>
    rights.foldLeft(left)(_ && _)
  }

  def notExpr: Parser[Boolean] =
    notPattern ~> notExpr ^^ (!_) |
      primaryExpr

  def primaryExpr: Parser[Boolean] =
    "(" ~> boolExpr <~ ")" |
      attempt(comparison) |
      booleanValue

  def comparison: Parser[Boolean] = value ~ compOp ~ value ^^ { case left ~ op ~ right =>
    compareValues(left, op, right)
  }

  def attempt[T](p: Parser[T]): Parser[T] = Parser { in =>
    p(in) match {
      case s @ Success(_, _) => s
      case _ => Failure("", in)
    }
  }

  def booleanValue: Parser[Boolean] =
    truePattern ^^ (_ => true) |
      falsePattern ^^ (_ => false) |
      ident.filter(id => !id.toUpperCase.matches("AND|OR|NOT")) ^^ { name =>
        variables.get(name) match {
          case Some(b: Boolean) => b
          case Some(other) =>
            throw new IllegalArgumentException(s"Expected boolean value for $name, got: $other")
          case None =>
            throw new IllegalArgumentException(s"Unknown variable: $name")
        }
      }

  def compOp: Parser[String] = ">=" | "<=" | "!=" | "<>" | "=" | ">" | "<"

  def value: Parser[Any] =
    floatingPointNumber ^^ (_.toDouble) |
      wholeNumber ^^ (_.toLong) |
      stringLiteral ^^ (s => s.substring(1, s.length - 1)) | // Remove quotes
      truePattern ^^ (_ => true) |
      falsePattern ^^ (_ => false) |
      nullPattern ^^ (_ => null) |
      ident.filter(id => !id.toUpperCase.matches("AND|OR|NOT")) ^^ (name =>
        variables.getOrElse(name, throw new IllegalArgumentException(s"Unknown variable: $name")))

  private def compareValues(left: Any, op: String, right: Any): Boolean = {
    (left, right) match {
      case (null, null) => op == "=" || op == "<=" || op == ">="
      case (null, _) | (_, null) => op == "!=" || op == "<>"
      case _ =>
        val comparison = compareNonNull(left, right)
        op match {
          case "=" => comparison == 0
          case "!=" | "<>" => comparison != 0
          case "<" => comparison < 0
          case "<=" => comparison <= 0
          case ">" => comparison > 0
          case ">=" => comparison >= 0
        }
    }
  }

  private def compareNonNull(left: Any, right: Any): Int = {
    (left, right) match {
      case (l: Number, r: Number) =>
        val ld = l.doubleValue()
        val rd = r.doubleValue()
        if (ld < rd) -1 else if (ld > rd) 1 else 0
      case (l: String, r: String) => l.compareTo(r)
      case (l: Boolean, r: Boolean) => l.compareTo(r)
      case _ =>
        // Try to compare as strings as a fallback
        left.toString.compareTo(right.toString)
    }
  }
}
@@ -1038,4 +1038,13 @@ object st_functions {
      selfWeight,
      useSpheroid,
      attributes)

  def barrier(expression: Column, args: Column*): Column = {
    val allArgs = expression +: args
    wrapExpression[Barrier](allArgs: _*)
  }
  def barrier(expression: String, args: Any*): Column = {
    val allArgs = expression +: args
    wrapExpression[Barrier](allArgs: _*)
  }
}
@@ -597,9 +597,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
case None =>
Nil
}
val objectSidePlan = if (querySide == LeftSide) right else left

checkObjectPlanFilterPushdown(objectSidePlan)

logInfo(
"Planning knn join, left side is for queries and right side is for the object to be searched")
@@ -737,10 +734,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
case None =>
Nil
}
val objectSidePlan = if (querySide == LeftSide) right else left

checkObjectPlanFilterPushdown(objectSidePlan)

if (querySide == broadcastSide.get) {
// broadcast is on query side
return BroadcastQuerySideKNNJoinExec(
@@ -967,35 +960,4 @@ class JoinQueryDetector(sparkSession: SparkSession) extends SparkStrategy {
case other => other.children.exists(containPlanFilterPushdown)
}
}

/**
* Check if the given plan has a filter that can be pushed down to the object side of the KNN
* join. Print a warning if a filter pushdown is detected.
* @param objectSidePlan
*/
private def checkObjectPlanFilterPushdown(objectSidePlan: LogicalPlan): Unit = {
if (containPlanFilterPushdown(objectSidePlan)) {
val warnings = Seq(
"Warning: One or more filter pushdowns have been detected on the object side of the KNN join. \n" +
"These filters will be applied to the object side reader before the KNN join is executed. \n" +
"If you intend to apply the filters after the KNN join, please ensure that you materialize the KNN join results before applying the filters. \n" +
"For example, you can use the following approach:\n\n" +

// Scala Example
"Scala Example:\n" +
"val knnResult = knnJoinDF.cache()\n" +
"val filteredResult = knnResult.filter(condition)\n\n" +

// SQL Example
"SQL Example:\n" +
"CREATE OR REPLACE TEMP VIEW knnResult AS\n" +
"SELECT * FROM (\n" +
" -- Your KNN join SQL here\n" +
") AS knnView\n" +
"CACHE TABLE knnResult;\n" +
"SELECT * FROM knnResult WHERE condition;")
logWarning(warnings.mkString("\n"))
println(warnings.mkString("\n"))
}
}
}