Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion crates/polars-expr/src/expressions/column.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::borrow::Cow;

use polars_core::prelude::*;
use polars_plan::constants::CSE_REPLACED;
use polars_plan::constants::{CSE_REPLACED, POLARS_ELEMENT};

use super::*;
use crate::expressions::{AggregationContext, PartitionedAggregation, PhysicalExpr};
Expand Down Expand Up @@ -141,6 +141,22 @@ impl PhysicalExpr for ColumnExpr {
groups: &'a GroupPositions,
state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
if let Some(state) = state.ext_named_groups.get(&self.name) {
match state {
AggState::LiteralScalar(c) => assert_eq!(c.len(), 1),
AggState::AggregatedScalar(c) => assert_eq!(c.len(), groups.len()),
AggState::AggregatedList(c) => assert_eq!(c.len(), groups.len()),
AggState::NotAggregated(_) => {},
}

return Ok(AggregationContext {
state: state.clone(),
groups: Cow::Borrowed(groups),
update_groups: UpdateGroups::No,
original_len: false,
});
}

let c = self.evaluate(df, state)?;
Ok(AggregationContext::new(c, Cow::Borrowed(groups), false))
}
Expand Down
38 changes: 28 additions & 10 deletions crates/polars-expr/src/expressions/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use polars_core::frame::DataFrame;
use polars_core::prelude::ArrayChunked;
use polars_core::prelude::{
AnyValue, ChunkCast, ChunkExplode, Column, Field, GroupPositions, GroupsType, IntoColumn,
ListBuilderTrait, ListChunked,
ListBuilderTrait, ListChunked, PlHashMap,
};
use polars_core::schema::Schema;
use polars_core::series::Series;
Expand All @@ -19,7 +19,7 @@ use polars_utils::IdxSize;
use polars_utils::pl_str::PlSmallStr;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

use super::{AggregationContext, PhysicalExpr};
use super::{AggState, AggregationContext, PhysicalExpr};
use crate::state::ExecutionState;

#[derive(Clone)]
Expand All @@ -34,6 +34,7 @@ pub struct EvalExpr {
evaluation_is_scalar: bool,
evaluation_is_elementwise: bool,
evaluation_is_fallible: bool,
uses_ext_columns: bool,
}

impl EvalExpr {
Expand All @@ -49,6 +50,7 @@ impl EvalExpr {
evaluation_is_scalar: bool,
evaluation_is_elementwise: bool,
evaluation_is_fallible: bool,
uses_ext_columns: bool,
) -> Self {
Self {
input,
Expand All @@ -61,6 +63,7 @@ impl EvalExpr {
evaluation_is_scalar,
evaluation_is_elementwise,
evaluation_is_fallible,
uses_ext_columns,
}
}

Expand All @@ -69,6 +72,7 @@ impl EvalExpr {
ca: &ListChunked,
state: &ExecutionState,
is_agg: bool,
ext_df: &DataFrame,
) -> PolarsResult<Column> {
let df = ca.get_inner().with_name(PlSmallStr::EMPTY).into_frame();

Expand All @@ -78,11 +82,14 @@ impl EvalExpr {
return Ok(Column::full_null(name, ca.len(), self.output_field.dtype()));
}

let has_masked_out_values = LazyCell::new(|| ca.has_masked_out_values());
let may_fail_on_masked_out_elements = self.evaluation_is_fallible && *has_masked_out_values;
let has_masked_out_values = ca.has_masked_out_values();
let may_fail_on_masked_out_elements = self.evaluation_is_fallible && has_masked_out_values;

// Fast path: fully elementwise expression without masked out values.
if self.evaluation_is_elementwise && !may_fail_on_masked_out_elements {
if self.evaluation_is_elementwise
&& !self.uses_ext_columns
&& !may_fail_on_masked_out_elements
{
let mut column = self.evaluation.evaluate(&df, state)?;

// Since `lit` is marked as elementwise, this may lead to problems.
Expand Down Expand Up @@ -124,7 +131,17 @@ impl EvalExpr {
};
let groups = Cow::Owned(groups.into_sliceable());

let mut ac = self.evaluation.evaluate_on_groups(&df, &groups, state)?;
let mut state = Cow::Borrowed(state);
if self.uses_ext_columns {
state.to_mut().ext_named_groups =
Arc::new(PlHashMap::from_iter(ext_df.column_iter().map(|c| {
(c.name().clone(), AggState::AggregatedScalar(c.clone()))
})));
}

let mut ac = self
.evaluation
.evaluate_on_groups(&df, &groups, state.as_ref())?;

ac.groups(); // Update the groups.

Expand Down Expand Up @@ -405,11 +422,11 @@ impl PhysicalExpr for EvalExpr {
match self.variant {
EvalVariant::List => {
let lst = input.list()?;
self.evaluate_on_list_chunked(lst, state, false)
self.evaluate_on_list_chunked(lst, state, false, df)
},
EvalVariant::ListAgg => {
let lst = input.list()?;
self.evaluate_on_list_chunked(lst, state, true)
self.evaluate_on_list_chunked(lst, state, true, df)
},
EvalVariant::Array { as_list } => feature_gated!("dtype-array", {
self.evaluate_on_array_chunked(input.array()?, state, as_list, false)
Expand All @@ -433,11 +450,12 @@ impl PhysicalExpr for EvalExpr {
match self.variant {
EvalVariant::List => {
let out =
self.evaluate_on_list_chunked(input.get_values().list()?, state, false)?;
self.evaluate_on_list_chunked(input.get_values().list()?, state, false, df)?;
input.with_values(out, false, Some(&self.expr))?;
},
EvalVariant::ListAgg => {
let out = self.evaluate_on_list_chunked(input.get_values().list()?, state, true)?;
let out =
self.evaluate_on_list_chunked(input.get_values().list()?, state, true, df)?;
input.with_values(out, false, Some(&self.expr))?;
},
EvalVariant::Array { as_list } => feature_gated!("dtype-array", {
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,11 @@ fn create_physical_expr_inner(
pd_group.update_with_expr_rec(expr_arena.get(*evaluation), expr_arena, None);
let evaluation_is_fallible = matches!(pd_group, ExprPushdownGroup::Fallible);

let has_ext_columns = aexpr_to_leaf_names_iter(*evaluation, expr_arena)
.filter(|n| n.is_empty())
.count()
> 0;

let output_field = expr_arena
.get(expression)
.to_field(&ToFieldContext::new(expr_arena, schema))?;
Expand All @@ -531,7 +536,8 @@ fn create_physical_expr_inner(
create_physical_expr_inner(*expr, Context::Default, expr_arena, schema, state)?;

let element_dtype = variant.element_dtype(&input_field.dtype)?;
let eval_schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
let mut eval_schema = schema.as_ref().clone();
eval_schema.insert(PlSmallStr::EMPTY, element_dtype.clone());
let evaluation = create_physical_expr_inner(
*evaluation,
// @Hack. Since EvalVariant::Array uses `evaluate_on_groups` to determine the
Expand Down Expand Up @@ -561,6 +567,7 @@ fn create_physical_expr_inner(
evaluation_is_scalar,
evaluation_is_elementwise,
evaluation_is_fallible,
has_ext_columns,
)))
},
Function {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-expr/src/state/execution_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::unique_id::UniqueId;

use super::NodeTimer;
use crate::prelude::{AggState, AggregationContext};

pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;

Expand Down Expand Up @@ -118,6 +119,8 @@ pub struct ExecutionState {
pub branch_idx: usize,
pub flags: RelaxedCell<u8>,
pub ext_contexts: Arc<Vec<DataFrame>>,
/// External aggregations that can be provided by name.
pub ext_named_groups: Arc<PlHashMap<PlSmallStr, AggState>>,
node_timer: Option<NodeTimer>,
stop: Arc<RelaxedCell<bool>>,
}
Expand All @@ -135,6 +138,7 @@ impl ExecutionState {
branch_idx: 0,
flags: RelaxedCell::from(StateFlags::init().as_u8()),
ext_contexts: Default::default(),
ext_named_groups: Default::default(),
node_timer: None,
stop: Arc::new(RelaxedCell::from(false)),
}
Expand Down Expand Up @@ -199,6 +203,7 @@ impl ExecutionState {
branch_idx: self.branch_idx,
flags: self.flags.clone(),
ext_contexts: self.ext_contexts.clone(),
ext_named_groups: self.ext_named_groups.clone(),
node_timer: self.node_timer.clone(),
stop: self.stop.clone(),
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use polars_utils::pl_str::PlSmallStr;
pub static CSE_REPLACED: &str = "__POLARS_CSER_";
pub static POLARS_TMP_PREFIX: &str = "_POLARS_";
pub static POLARS_PLACEHOLDER: &str = "_POLARS_<>";
pub static POLARS_ELEMENT: &str = "__PL_ELEMENT";
pub const LEN: &str = "len";
const LITERAL_NAME: &str = "literal";

// Cache the often used LITERAL and LEN constants
static LITERAL_NAME_INIT: OnceLock<PlSmallStr> = OnceLock::new();
static LEN_INIT: OnceLock<PlSmallStr> = OnceLock::new();
pub static PL_ELEMENT_NAME: PlSmallStr = PlSmallStr::from_static(POLARS_ELEMENT);

pub fn get_literal_name() -> &'static PlSmallStr {
LITERAL_NAME_INIT.get_or_init(|| PlSmallStr::from_static(LITERAL_NAME))
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ impl AExpr {
let field = ctx.arena.get(*expr).to_field_impl(ctx)?;

let element_dtype = variant.element_dtype(field.dtype())?;
let schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
let mut schema = ctx.schema.clone();
schema.insert(PlSmallStr::EMPTY, element_dtype.clone());

dbg!(&schema);

let ctx = ToFieldContext {
schema: &schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,17 +785,6 @@ fn expand_expression_rec(
evaluation,
variant,
} => {
// Perform this before schema resolution so that we can better error messages.
for e in evaluation.as_ref().into_iter() {
if let Expr::Column(name) = e {
polars_ensure!(
name.is_empty(),
ComputeError:
"named columns are not allowed in `eval` functions; consider using `element`"
);
}
}

let mut tmp = Vec::with_capacity(1);
expand_expression_rec(expr, ignored_selector_columns, schema, &mut tmp, opt_flags)?;

Expand Down
14 changes: 3 additions & 11 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::functions::convert_functions;
use super::*;
use crate::constants::PL_ELEMENT_NAME;
use crate::plans::iterator::ArenaExprIter;

pub fn to_expr_ir(expr: Expr, ctx: &mut ExprToIRContext) -> PolarsResult<ExprIR> {
Expand Down Expand Up @@ -416,18 +417,9 @@ pub(super) fn to_aexpr_impl(
let expr_dtype = ctx.arena.get(expr).to_dtype(&ctx.to_field_ctx())?;
let element_dtype = variant.element_dtype(&expr_dtype)?;

// Perform this before schema resolution so that we can better error messages.
for e in evaluation.as_ref().into_iter() {
if let Expr::Column(name) = e {
polars_ensure!(
name.is_empty(),
ComputeError:
"named columns are not allowed in `eval` functions; consider using `element`"
);
}
}
let mut evaluation_schema = ctx.schema.clone();
evaluation_schema.insert(PlSmallStr::EMPTY, element_dtype.clone());

let evaluation_schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
let mut evaluation_ctx = ExprToIRContext {
with_fields: None,
schema: &evaluation_schema,
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/plans/conversion/stack_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ impl ConversionOptimizer {
.to_dtype(&ToFieldContext::new(expr_arena, schema))?;

let element_dtype = variant.element_dtype(&expr)?;
let schema = Schema::from_iter([(PlSmallStr::EMPTY, element_dtype.clone())]);
let mut schema = schema.clone();
schema.insert(PlSmallStr::EMPTY, element_dtype.clone());
self.schemas.push(schema);
self.scratch.push((*evaluation, self.schemas.len()));
}
Expand Down
Loading