diff --git a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs index 730ceab8ccef3..3b1b43bc09169 100644 --- a/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs @@ -366,10 +366,22 @@ pub trait SubstraitConsumer: Send + Sync + Sized { async fn consume_dynamic_parameter( &self, - _expr: &DynamicParameter, + expr: &DynamicParameter, _input_schema: &DFSchema, ) -> datafusion::common::Result { - not_impl_err!("Dynamic Parameter expression not supported") + let id = format!("${}", expr.parameter_reference + 1); + let field = expr + .r#type + .as_ref() + .map(|t| { + super::from_substrait_type_without_names(self, t).map(|dt| { + Arc::new(datafusion::arrow::datatypes::Field::new(&id, dt, true)) + }) + }) + .transpose()?; + Ok(Expr::Placeholder( + datafusion::logical_expr::expr::Placeholder::new_with_field(id, field), + )) } // Outer Schema Stack diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index d130961596dc9..b07a0206c53bb 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -20,6 +20,7 @@ mod cast; mod field_reference; mod if_then; mod literal; +mod placeholder; mod scalar_function; mod singular_or_list; mod subquery; @@ -30,6 +31,7 @@ pub use cast::*; pub use field_reference::*; pub use if_then::*; pub use literal::*; +pub use placeholder::*; pub use scalar_function::*; pub use singular_or_list::*; pub use subquery::*; @@ -142,7 +144,7 @@ pub fn to_substrait_rex( #[expect(deprecated)] Expr::Wildcard { .. } => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::GroupingSet(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), - Expr::Placeholder(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Placeholder(expr) => producer.handle_placeholder(expr, schema), Expr::OuterReferenceColumn(_, _) => { // OuterReferenceColumn requires tracking outer query schema context for correlated // subqueries. This is a complex feature that is not yet implemented. diff --git a/datafusion/substrait/src/logical_plan/producer/expr/placeholder.rs b/datafusion/substrait/src/logical_plan/producer/expr/placeholder.rs new file mode 100644 index 0000000000000..44721abc0ce5f --- /dev/null +++ b/datafusion/substrait/src/logical_plan/producer/expr/placeholder.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type}; +use datafusion::common::substrait_err; +use datafusion::logical_expr::expr::Placeholder; +use substrait::proto::expression::RexType; +use substrait::proto::{DynamicParameter, Expression}; + +pub fn from_placeholder( + producer: &mut impl SubstraitProducer, + placeholder: &Placeholder, +) -> datafusion::common::Result { + let parameter_reference = parse_placeholder_index(&placeholder.id)?; + + let r#type = placeholder + .field + .as_ref() + .map(|field| to_substrait_type(producer, field.data_type(), field.is_nullable())) + .transpose()?; + + Ok(Expression { + rex_type: Some(RexType::DynamicParameter(DynamicParameter { + r#type, + parameter_reference, + })), + }) +} + +/// Converts a placeholder id like "$1" into a zero-based parameter index. +/// Substrait uses zero-based `parameter_reference` while DataFusion uses +/// one-based `$N` placeholder ids. +fn parse_placeholder_index(id: &str) -> datafusion::common::Result { + let num_str = id.strip_prefix('$').unwrap_or(id); + match num_str.parse::() { + Ok(n) if n > 0 => Ok(n - 1), + Ok(_) => substrait_err!("Placeholder index must be >= 1, got: {id}"), + Err(_) => substrait_err!("Cannot parse placeholder id as numeric index: {id}"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_placeholder_index() { + assert_eq!(parse_placeholder_index("$1").unwrap(), 0); + assert_eq!(parse_placeholder_index("$2").unwrap(), 1); + assert_eq!(parse_placeholder_index("$100").unwrap(), 99); + assert!(parse_placeholder_index("$0").is_err()); + assert!(parse_placeholder_index("$name").is_err()); + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index 51d2c0ca8e783..9a6c5317ec044 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -20,17 +20,17 @@ use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_exists, from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, - from_literal, from_projection, from_repartition, from_scalar_function, - from_scalar_subquery, from_set_comparison, from_sort, from_subquery_alias, - from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, - from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_literal, from_placeholder, from_projection, from_repartition, + from_scalar_function, from_scalar_subquery, from_set_comparison, from_sort, + from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, + from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, }; use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::logical_expr::Subquery; use datafusion::logical_expr::expr::{ - Alias, Exists, InList, InSubquery, SetComparison, WindowFunction, + Alias, Exists, InList, InSubquery, Placeholder, SetComparison, WindowFunction, }; use datafusion::logical_expr::{ Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, @@ -388,6 +388,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_exists(self, exists, schema) } + + fn handle_placeholder( + &mut self, + placeholder: &Placeholder, + _schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_placeholder(self, placeholder) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 5dd4aa4e2be91..4b47e8dcebc17 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -1754,6 +1754,140 @@ async fn roundtrip_read_filter() -> Result<()> { roundtrip_verify_read_filter_count("SELECT a FROM data where a < 5", 1).await } +#[tokio::test] +async fn roundtrip_placeholder_sql_filter() -> Result<()> { + let plan = generate_plan_from_sql("SELECT a, b FROM data WHERE a > $1", false, false) + .await?; + + assert_snapshot!( + plan, + @r" + Projection: data.a, data.b + Filter: data.a > $1 + TableScan: data + " + ); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_placeholder_sql_projection() -> Result<()> { + let plan = + generate_plan_from_sql("SELECT a, $1 FROM data WHERE a > $2", false, false) + .await?; + + assert_snapshot!( + plan, + @r" + Projection: data.a, $1 + Filter: data.a > $2 + TableScan: data + " + ); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_placeholder_typed_int64() -> Result<()> { + let ctx = create_context().await?; + + let placeholder = + Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field( + "$1".into(), + Some(Arc::new(Field::new("$1", DataType::Int64, true))), + )); + let scan_plan = ctx.table("data").await?.into_optimized_plan()?; + let plan = LogicalPlanBuilder::from(scan_plan) + .filter(col("a").gt(placeholder))? + .build()?; + + let proto = to_substrait_plan(&plan, &ctx.state())?; + + // Verify the producer emits a DynamicParameter in the Substrait proto + let plan_rel = proto.relations.first().unwrap(); + let plan_json = format!("{plan_rel:?}"); + assert!( + plan_json.contains("DynamicParameter"), + "Substrait proto should contain DynamicParameter, got: {plan_json}" + ); + + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + + assert_snapshot!( + plan2, + @r" + Filter: data.a > $1 + TableScan: data + " + ); + + assert_eq!(plan.schema(), plan2.schema()); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_placeholder_multiple_typed() -> Result<()> { + let ctx = create_context().await?; + + let p1 = + Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field( + "$1".into(), + Some(Arc::new(Field::new("$1", DataType::Int64, true))), + )); + let p2 = + Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field( + "$2".into(), + Some(Arc::new(Field::new("$2", DataType::Decimal128(5, 2), true))), + )); + let scan_plan = ctx.table("data").await?.into_optimized_plan()?; + let plan = LogicalPlanBuilder::from(scan_plan) + .filter(col("a").gt(p1).and(col("b").lt(p2)))? + .build()?; + + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + + assert_snapshot!( + plan2, + @r" + Filter: data.a > $1 AND data.b < $2 + TableScan: data + " + ); + + assert_eq!(plan.schema(), plan2.schema()); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_placeholder_typed_utf8() -> Result<()> { + let ctx = create_context().await?; + + let placeholder = + Expr::Placeholder(datafusion::logical_expr::expr::Placeholder::new_with_field( + "$1".into(), + Some(Arc::new(Field::new("$1", DataType::Utf8, true))), + )); + let scan_plan = ctx.table("data").await?.into_optimized_plan()?; + let plan = LogicalPlanBuilder::from(scan_plan) + .filter(col("f").eq(placeholder))? + .build()?; + + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + + assert_snapshot!( + plan2, + @r" + Filter: data.f = $1 + TableScan: data + " + ); + + assert_eq!(plan.schema(), plan2.schema()); + Ok(()) +} + fn check_post_join_filters(rel: &Rel) -> Result<()> { // search for target_rel and field value in proto match &rel.rel_type {