Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized {

async fn consume_dynamic_parameter(
&self,
_expr: &DynamicParameter,
expr: &DynamicParameter,
_input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
not_impl_err!("Dynamic Parameter expression not supported")
let id = format!("${}", expr.parameter_reference + 1);
let field = expr
.r#type
.as_ref()
.map(|t| {
super::from_substrait_type_without_names(self, t).map(|dt| {
Arc::new(datafusion::arrow::datatypes::Field::new(&id, dt, true))
})
})
.transpose()?;
Ok(Expr::Placeholder(
datafusion::logical_expr::expr::Placeholder::new_with_field(id, field),
))
}

// Outer Schema Stack
Expand Down
4 changes: 3 additions & 1 deletion datafusion/substrait/src/logical_plan/producer/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod cast;
mod field_reference;
mod if_then;
mod literal;
mod placeholder;
mod scalar_function;
mod singular_or_list;
mod subquery;
Expand All @@ -30,6 +31,7 @@ pub use cast::*;
pub use field_reference::*;
pub use if_then::*;
pub use literal::*;
pub use placeholder::*;
pub use scalar_function::*;
pub use singular_or_list::*;
pub use subquery::*;
Expand Down Expand Up @@ -142,7 +144,7 @@ pub fn to_substrait_rex(
#[expect(deprecated)]
Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"),
Expr::Placeholder(expr) => producer.handle_placeholder(expr, schema),
Expr::OuterReferenceColumn(_, _) => {
// OuterReferenceColumn requires tracking outer query schema context for correlated
// subqueries. This is a complex feature that is not yet implemented.
Expand Down
68 changes: 68 additions & 0 deletions datafusion/substrait/src/logical_plan/producer/expr/placeholder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type};
use datafusion::common::substrait_err;
use datafusion::logical_expr::expr::Placeholder;
use substrait::proto::expression::RexType;
use substrait::proto::{DynamicParameter, Expression};

pub fn from_placeholder(
producer: &mut impl SubstraitProducer,
placeholder: &Placeholder,
) -> datafusion::common::Result<Expression> {
let parameter_reference = parse_placeholder_index(&placeholder.id)?;

let r#type = placeholder
.field
.as_ref()
.map(|field| to_substrait_type(producer, field.data_type(), field.is_nullable()))
.transpose()?;

Ok(Expression {
rex_type: Some(RexType::DynamicParameter(DynamicParameter {
r#type,
parameter_reference,
})),
})
}

/// Converts a placeholder id like "$1" into a zero-based parameter index.
/// Substrait uses zero-based `parameter_reference` while DataFusion uses
/// one-based `$N` placeholder ids.
fn parse_placeholder_index(id: &str) -> datafusion::common::Result<u32> {
let num_str = id.strip_prefix('$').unwrap_or(id);
match num_str.parse::<u32>() {
Ok(n) if n > 0 => Ok(n - 1),
Ok(_) => substrait_err!("Placeholder index must be >= 1, got: {id}"),
Err(_) => substrait_err!("Cannot parse placeholder id as numeric index: {id}"),
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_parse_placeholder_index() {
assert_eq!(parse_placeholder_index("$1").unwrap(), 0);
assert_eq!(parse_placeholder_index("$2").unwrap(), 1);
assert_eq!(parse_placeholder_index("$100").unwrap(), 99);
assert!(parse_placeholder_index("$0").is_err());
assert!(parse_placeholder_index("$name").is_err());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ use crate::logical_plan::producer::{
from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr,
from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists,
from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit,
from_literal, from_projection, from_repartition, from_scalar_function,
from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias,
from_table_scan, from_try_cast, from_unary_expr, from_union, from_values,
from_window, from_window_function, to_substrait_rel, to_substrait_rex,
from_literal, from_placeholder, from_projection, from_repartition,
from_scalar_function, from_scalar_subquery, from_set_comparison, from_sort,
from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union,
from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex,
};
use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err};
use datafusion::execution::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::logical_expr::Subquery;
use datafusion::logical_expr::expr::{
Alias, Exists, InList, InSubquery, SetComparison, WindowFunction,
Alias, Exists, InList, InSubquery, Placeholder, SetComparison, WindowFunction,
};
use datafusion::logical_expr::{
Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension,
Expand Down Expand Up @@ -388,6 +388,14 @@ pub trait SubstraitProducer: Send + Sync + Sized {
) -> datafusion::common::Result<Expression> {
from_exists(self, exists, schema)
}

fn handle_placeholder(
&mut self,
placeholder: &Placeholder,
_schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
from_placeholder(self, placeholder)
}
}

pub struct DefaultSubstraitProducer<'a> {
Expand Down
134 changes: 134 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,140 @@ async fn roundtrip_read_filter() -> Result<()> {
roundtrip_verify_read_filter_count("SELECT a FROM data where a < 5", 1).await
}

#[tokio::test]
async fn roundtrip_placeholder_sql_filter() -> Result<()> {
let plan = generate_plan_from_sql("SELECT a, b FROM data WHERE a > $1", false, false)
.await?;

assert_snapshot!(
plan,
@r"
Projection: data.a, data.b
Filter: data.a > $1
TableScan: data
"
);
Ok(())
}

#[tokio::test]
async fn roundtrip_placeholder_sql_projection() -> Result<()> {
let plan =
generate_plan_from_sql("SELECT a, $1 FROM data WHERE a > $2", false, false)
.await?;

assert_snapshot!(
plan,
@r"
Projection: data.a, $1
Filter: data.a > $2
TableScan: data
"
);
Ok(())
}

#[tokio::test]
async fn roundtrip_placeholder_typed_int64() -> Result<()> {
let ctx = create_context().await?;

let placeholder =
Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field(
"$1".into(),
Some(Arc::new(Field::new("$1", DataType::Int64, true))),
));
let scan_plan = ctx.table("data").await?.into_optimized_plan()?;
let plan = LogicalPlanBuilder::from(scan_plan)
.filter(col("a").gt(placeholder))?
.build()?;

let proto = to_substrait_plan(&plan, &ctx.state())?;

// Verify the producer emits a DynamicParameter in the Substrait proto
let plan_rel = proto.relations.first().unwrap();
let plan_json = format!("{plan_rel:?}");
assert!(
plan_json.contains("DynamicParameter"),
"Substrait proto should contain DynamicParameter, got: {plan_json}"
);

let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;

assert_snapshot!(
plan2,
@r"
Filter: data.a > $1
TableScan: data
"
);

assert_eq!(plan.schema(), plan2.schema());
Ok(())
}

#[tokio::test]
async fn roundtrip_placeholder_multiple_typed() -> Result<()> {
let ctx = create_context().await?;

let p1 =
Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field(
"$1".into(),
Some(Arc::new(Field::new("$1", DataType::Int64, true))),
));
let p2 =
Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field(
"$2".into(),
Some(Arc::new(Field::new("$2", DataType::Decimal128(5, 2), true))),
));
let scan_plan = ctx.table("data").await?.into_optimized_plan()?;
let plan = LogicalPlanBuilder::from(scan_plan)
.filter(col("a").gt(p1).and(col("b").lt(p2)))?
.build()?;

let proto = to_substrait_plan(&plan, &ctx.state())?;
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;

assert_snapshot!(
plan2,
@r"
Filter: data.a > $1 AND data.b < $2
TableScan: data
"
);

assert_eq!(plan.schema(), plan2.schema());
Ok(())
}

#[tokio::test]
async fn roundtrip_placeholder_typed_utf8() -> Result<()> {
let ctx = create_context().await?;

let placeholder =
Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field(
"$1".into(),
Some(Arc::new(Field::new("$1", DataType::Utf8, true))),
));
let scan_plan = ctx.table("data").await?.into_optimized_plan()?;
let plan = LogicalPlanBuilder::from(scan_plan)
.filter(col("f").eq(placeholder))?
.build()?;

let proto = to_substrait_plan(&plan, &ctx.state())?;
let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;

assert_snapshot!(
plan2,
@r"
Filter: data.f = $1
TableScan: data
"
);

assert_eq!(plan.schema(), plan2.schema());
Ok(())
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
Expand Down
Loading