diff --git a/crates/polars-expr/src/expressions/column.rs b/crates/polars-expr/src/expressions/column.rs index f45a036cba9d..8c2ea88a0660 100644 --- a/crates/polars-expr/src/expressions/column.rs +++ b/crates/polars-expr/src/expressions/column.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use polars_core::prelude::*; -use polars_plan::constants::CSE_REPLACED; +use polars_plan::constants::{CSE_REPLACED, POLARS_ELEMENT}; use super::*; use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr}; @@ -141,6 +141,22 @@ impl PhysicalExpr for ColumnExpr { groups: &'a GroupPositions, state: &ExecutionState, ) -> PolarsResult> { + if let Some(state) = state.ext_named_groups.get(&self.name) { + match state { + AggState::LiteralScalar(c) => assert_eq!(c.len(), 1), + AggState::AggregatedScalar(c) => assert_eq!(c.len(), groups.len()), + AggState::AggregatedList(c) => assert_eq!(c.len(), groups.len()), + AggState::NotAggregated(_) => {}, + } + + return Ok(AggregationContext { + state: state.clone(), + groups: Cow::Borrowed(groups), + update_groups: UpdateGroups::No, + original_len: false, + }); + } + let c = self.evaluate(df, state)?; Ok(AggregationContext::new(c, Cow::Borrowed(groups), false)) } diff --git a/crates/polars-expr/src/expressions/eval.rs b/crates/polars-expr/src/expressions/eval.rs index 100bdf55d7c5..de9ace1eeb88 100644 --- a/crates/polars-expr/src/expressions/eval.rs +++ b/crates/polars-expr/src/expressions/eval.rs @@ -10,7 +10,7 @@ use polars_core::frame::DataFrame; use polars_core::prelude::ArrayChunked; use polars_core::prelude::{ AnyValue, ChunkCast, ChunkExplode, Column, Field, GroupPositions, GroupsType, IntoColumn, - ListBuilderTrait, ListChunked, + ListBuilderTrait, ListChunked, PlHashMap, }; use polars_core::schema::Schema; use polars_core::series::Series; @@ -19,7 +19,7 @@ use polars_utils::IdxSize; use polars_utils::pl_str::PlSmallStr; use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use super::{AggregationContext, PhysicalExpr}; +use super::{AggState, AggregationContext, PhysicalExpr}; use crate::state::ExecutionState; #[derive(Clone)] @@ -34,6 +34,7 @@ pub struct EvalExpr { evaluation_is_scalar: bool, evaluation_is_elementwise: bool, evaluation_is_fallible: bool, + uses_ext_columns: bool, } impl EvalExpr { @@ -49,6 +50,7 @@ impl EvalExpr { evaluation_is_scalar: bool, evaluation_is_elementwise: bool, evaluation_is_fallible: bool, + uses_ext_columns: bool, ) -> Self { Self { input, @@ -61,6 +63,7 @@ impl EvalExpr { evaluation_is_scalar, evaluation_is_elementwise, evaluation_is_fallible, + uses_ext_columns, } } @@ -69,6 +72,7 @@ impl EvalExpr { ca: &ListChunked, state: &ExecutionState, is_agg: bool, + ext_df: &DataFrame, ) -> PolarsResult { let df = ca.get_inner().with_name(PlSmallStr::EMPTY).into_frame(); @@ -78,11 +82,14 @@ impl EvalExpr { return Ok(Column::full_null(name, ca.len(), self.output_field.dtype())); } - let has_masked_out_values = LazyCell::new(|| ca.has_masked_out_values()); - let may_fail_on_masked_out_elements = self.evaluation_is_fallible && *has_masked_out_values; + let has_masked_out_values = ca.has_masked_out_values(); + let may_fail_on_masked_out_elements = self.evaluation_is_fallible && has_masked_out_values; // Fast path: fully elementwise expression without masked out values. - if self.evaluation_is_elementwise && !may_fail_on_masked_out_elements { + if self.evaluation_is_elementwise + && !self.uses_ext_columns + && !may_fail_on_masked_out_elements + { let mut column = self.evaluation.evaluate(&df, state)?; // Since `lit` is marked as elementwise, this may lead to problems. @@ -124,7 +131,17 @@ impl EvalExpr { }; let groups = Cow::Owned(groups.into_sliceable()); - let mut ac = self.evaluation.evaluate_on_groups(&df, &groups, state)?; + let mut state = Cow::Borrowed(state); + if self.uses_ext_columns { + state.to_mut().ext_named_groups = + Arc::new(PlHashMap::from_iter(ext_df.column_iter().map(|c| { + (c.name().clone(), AggState::AggregatedScalar(c.clone())) + }))); + } + + let mut ac = self + .evaluation + .evaluate_on_groups(&df, &groups, state.as_ref())?; ac.groups(); // Update the groups. @@ -405,11 +422,11 @@ impl PhysicalExpr for EvalExpr { match self.variant { EvalVariant::List => { let lst = input.list()?; - self.evaluate_on_list_chunked(lst, state, false) + self.evaluate_on_list_chunked(lst, state, false, df) }, EvalVariant::ListAgg => { let lst = input.list()?; - self.evaluate_on_list_chunked(lst, state, true) + self.evaluate_on_list_chunked(lst, state, true, df) }, EvalVariant::Array { as_list } => feature_gated!("dtype-array", { self.evaluate_on_array_chunked(input.array()?, state, as_list, false) @@ -433,11 +450,12 @@ impl PhysicalExpr for EvalExpr { match self.variant { EvalVariant::List => { let out = - self.evaluate_on_list_chunked(input.get_values().list()?, state, false)?; + self.evaluate_on_list_chunked(input.get_values().list()?, state, false, df)?; input.with_values(out, false, Some(&self.expr))?; }, EvalVariant::ListAgg => { - let out = self.evaluate_on_list_chunked(input.get_values().list()?, state, true)?; + let out = + self.evaluate_on_list_chunked(input.get_values().list()?, state, true, df)?; input.with_values(out, false, Some(&self.expr))?; }, EvalVariant::Array { as_list } => feature_gated!("dtype-array", { diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index dfb861d92978..c2fc5b92be7e 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -521,6 +521,11 @@ fn create_physical_expr_inner( pd_group.update_with_expr_rec(expr_arena.get(*evaluation), expr_arena, None); let evaluation_is_fallible = matches!(pd_group, ExprPushdownGroup::Fallible); + let has_ext_columns = aexpr_to_leaf_names_iter(*evaluation, expr_arena) + .filter(|n| n.is_empty()) + .count() + > 0; + let output_field = expr_arena .get(expression) .to_field(&ToFieldContext::new(expr_arena, schema))?; @@ -531,7 +536,8 @@ fn create_physical_expr_inner( create_physical_expr_inner(*expr, Context::Default, expr_arena, schema, state)?; let element_dtype = variant.element_dtype(&input_field.dtype)?; - let eval_schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]); + let mut eval_schema = schema.as_ref().clone(); + eval_schema.insert(PlSmallStr::EMPTY, element_dtype.clone()); let evaluation = create_physical_expr_inner( *evaluation, // @Hack. Since EvalVariant::Array uses `evaluate_on_groups` to determine the @@ -561,6 +567,7 @@ fn create_physical_expr_inner( evaluation_is_scalar, evaluation_is_elementwise, evaluation_is_fallible, + has_ext_columns, ))) }, Function { diff --git a/crates/polars-expr/src/state/execution_state.rs b/crates/polars-expr/src/state/execution_state.rs index da71186732b8..b6ceaa2da323 100644 --- a/crates/polars-expr/src/state/execution_state.rs +++ b/crates/polars-expr/src/state/execution_state.rs @@ -11,6 +11,7 @@ use polars_utils::relaxed_cell::RelaxedCell; use polars_utils::unique_id::UniqueId; use super::NodeTimer; +use crate::prelude::{AggState, AggregationContext}; pub type JoinTuplesCache = Arc>>; @@ -118,6 +119,8 @@ pub struct ExecutionState { pub branch_idx: usize, pub flags: RelaxedCell, pub ext_contexts: Arc>, + /// External aggregations that can be provided by name. + pub ext_named_groups: Arc>, node_timer: Option, stop: Arc>, } @@ -135,6 +138,7 @@ impl ExecutionState { branch_idx: 0, flags: RelaxedCell::from(StateFlags::init().as_u8()), ext_contexts: Default::default(), + ext_named_groups: Default::default(), node_timer: None, stop: Arc::new(RelaxedCell::from(false)), } @@ -199,6 +203,7 @@ impl ExecutionState { branch_idx: self.branch_idx, flags: self.flags.clone(), ext_contexts: self.ext_contexts.clone(), + ext_named_groups: self.ext_named_groups.clone(), node_timer: self.node_timer.clone(), stop: self.stop.clone(), } diff --git a/crates/polars-plan/src/constants.rs b/crates/polars-plan/src/constants.rs index 44016a6d1c91..307a80e87de8 100644 --- a/crates/polars-plan/src/constants.rs +++ b/crates/polars-plan/src/constants.rs @@ -5,12 +5,14 @@ use polars_utils::pl_str::PlSmallStr; pub static CSE_REPLACED: &str = "__POLARS_CSER_"; pub static POLARS_TMP_PREFIX: &str = "_POLARS_"; pub static POLARS_PLACEHOLDER: &str = "_POLARS_<>"; +pub static POLARS_ELEMENT: &str = "__PL_ELEMENT"; pub const LEN: &str = "len"; const LITERAL_NAME: &str = "literal"; // Cache the often used LITERAL and LEN constants static LITERAL_NAME_INIT: OnceLock = OnceLock::new(); static LEN_INIT: OnceLock = OnceLock::new(); +pub static PL_ELEMENT_NAME: PlSmallStr = PlSmallStr::from_static(POLARS_ELEMENT); pub fn get_literal_name() -> &'static PlSmallStr { LITERAL_NAME_INIT.get_or_init(|| PlSmallStr::from_static(LITERAL_NAME)) diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 46d8d337fec6..915ca4388355 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -248,7 +248,10 @@ impl AExpr { let field = ctx.arena.get(*expr).to_field_impl(ctx)?; let element_dtype = variant.element_dtype(field.dtype())?; - let schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]); + let mut schema = ctx.schema.clone(); + schema.insert(PlSmallStr::EMPTY, element_dtype.clone()); + + dbg!(&schema); let ctx = ToFieldContext { schema: &schema, diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs index 89c73555bb8b..2c5345c234bc 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs @@ -785,17 +785,6 @@ fn expand_expression_rec( evaluation, variant, } => { - // Perform this before schema resolution so that we can better error messages. - for e in evaluation.as_ref().into_iter() { - if let Expr::Column(name) = e { - polars_ensure!( - name.is_empty(), - ComputeError: - "named columns are not allowed in `eval` functions; consider using `element`" - ); - } - } - let mut tmp = Vec::with_capacity(1); expand_expression_rec(expr, ignored_selector_columns, schema, &mut tmp, opt_flags)?; diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs index 861f01fea78f..c33444fee3b1 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs @@ -1,5 +1,6 @@ use super::functions::convert_functions; use super::*; +use crate::constants::PL_ELEMENT_NAME; use crate::plans::iterator::ArenaExprIter; pub fn to_expr_ir(expr: Expr, ctx: &mut ExprToIRContext) -> PolarsResult { @@ -416,18 +417,9 @@ pub(super) fn to_aexpr_impl( let expr_dtype = ctx.arena.get(expr).to_dtype(&ctx.to_field_ctx())?; let element_dtype = variant.element_dtype(&expr_dtype)?; - // Perform this before schema resolution so that we can better error messages. - for e in evaluation.as_ref().into_iter() { - if let Expr::Column(name) = e { - polars_ensure!( - name.is_empty(), - ComputeError: - "named columns are not allowed in `eval` functions; consider using `element`" - ); - } - } + let mut evaluation_schema = ctx.schema.clone(); + evaluation_schema.insert(PlSmallStr::EMPTY, element_dtype.clone()); - let evaluation_schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]); let mut evaluation_ctx = ExprToIRContext { with_fields: None, schema: &evaluation_schema, diff --git a/crates/polars-plan/src/plans/conversion/stack_opt.rs b/crates/polars-plan/src/plans/conversion/stack_opt.rs index 0415fff6b512..fe87af36a6a3 100644 --- a/crates/polars-plan/src/plans/conversion/stack_opt.rs +++ b/crates/polars-plan/src/plans/conversion/stack_opt.rs @@ -144,7 +144,8 @@ impl ConversionOptimizer { .to_dtype(&ToFieldContext::new(expr_arena, schema))?; let element_dtype = variant.element_dtype(&expr)?; - let schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]); + let mut schema = schema.clone(); + schema.insert(PlSmallStr::EMPTY, element_dtype.clone()); self.schemas.push(schema); self.scratch.push((*evaluation, self.schemas.len())); }