diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index c6841ea0bdbc..6955de9811ee 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -875,6 +875,7 @@ pub enum GroupByMethod { Mean, First, Last, + Item, Sum, Groups, NUnique, @@ -897,6 +898,7 @@ impl Display for GroupByMethod { Mean => "mean", First => "first", Last => "last", + Item => "item", Sum => "sum", Groups => "groups", NUnique => "n_unique", @@ -922,6 +924,7 @@ pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr { Mean => format_pl_smallstr!("{name}_mean"), First => format_pl_smallstr!("{name}_first"), Last => format_pl_smallstr!("{name}_last"), + Item => format_pl_smallstr!("{name}_item"), Sum => format_pl_smallstr!("{name}_sum"), Groups => PlSmallStr::from_static("groups"), NUnique => format_pl_smallstr!("{name}_n_unique"), diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index 669641498986..f7dad36ada7e 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -504,6 +504,19 @@ on startup."#.trim_start()) ComputeError: "`strptime` / `to_datetime` was called with no format and no time zone, but a time zone is part of the data.\n\nThis was previously allowed but led to unpredictable and erroneous results. Give a format string, set a time zone or perform the operation eagerly on a Series instead of on an Expr." ) }; + (item_agg_count_not_one = $n:expr) => { + if $n == 0 { + polars_err!(ComputeError: + "aggregation 'item' expected a single value, got none" + ) + } else if $n > 1 { + polars_err!(ComputeError: + "aggregation 'item' expected a single value, got {} values", $n + ) + } else { + unreachable!() + } + }; } #[macro_export] diff --git a/crates/polars-expr/src/expressions/aggregation.rs b/crates/polars-expr/src/expressions/aggregation.rs index a1390470c797..a0a8aebc3629 100644 --- a/crates/polars-expr/src/expressions/aggregation.rs +++ b/crates/polars-expr/src/expressions/aggregation.rs @@ -125,6 +125,10 @@ impl PhysicalExpr for AggregationExpr { } else { s.tail(Some(1)) }), + GroupByMethod::Item => Ok(match s.len() { + 1 => s, + n => polars_bail!(item_agg_count_not_one = n), + }), GroupByMethod::Sum => parallel_op_columns( |s| s.sum_reduce().map(|sc| sc.into_column(s.name().clone())), s, @@ -332,6 +336,19 @@ impl PhysicalExpr for AggregationExpr { let agg_s = s.agg_last(&groups); AggregatedScalar(agg_s.with_name(keep_name)) }, + GroupByMethod::Item => { + let (s, groups) = ac.get_final_aggregation(); + for gc in groups.group_count().iter() { + match gc { + None | Some(1) => continue, + Some(n) => { + polars_bail!(item_agg_count_not_one = n); + }, + } + } + let agg_s = s.agg_first(&groups); + AggregatedScalar(agg_s.with_name(keep_name)) + }, GroupByMethod::NUnique => { let (s, groups) = ac.get_final_aggregation(); let agg_s = s.agg_n_unique(&groups); diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index dee0463db1a0..2511e595d1d4 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -378,6 +378,7 @@ fn create_physical_expr_inner( I::NUnique(_) => GBM::NUnique, I::First(_) => GBM::First, I::Last(_) => GBM::Last, + I::Item(_) => GBM::Item, I::Mean(_) => GBM::Mean, I::Implode(_) => GBM::Implode, I::Quantile { .. } => unreachable!(), diff --git a/crates/polars-expr/src/reduce/convert.rs b/crates/polars-expr/src/reduce/convert.rs index 286aa924f78f..7f29be1cedf2 100644 --- a/crates/polars-expr/src/reduce/convert.rs +++ b/crates/polars-expr/src/reduce/convert.rs @@ -11,7 +11,7 @@ use crate::reduce::bitwise::{ new_bitwise_and_reduction, new_bitwise_or_reduction, new_bitwise_xor_reduction, }; use crate::reduce::count::{CountReduce, NullCountReduce}; -use crate::reduce::first_last::{new_first_reduction, new_last_reduction}; +use crate::reduce::first_last::{new_first_reduction, new_item_reduction, new_last_reduction}; use crate::reduce::len::LenReduce; use crate::reduce::mean::new_mean_reduction; use crate::reduce::min_max::{new_max_reduction, new_min_reduction}; @@ -51,6 +51,7 @@ pub fn into_reduction( }, IRAggExpr::First(input) => (new_first_reduction(get_dt(*input)?), *input), IRAggExpr::Last(input) => (new_last_reduction(get_dt(*input)?), *input), + IRAggExpr::Item(input) => (new_item_reduction(get_dt(*input)?), *input), IRAggExpr::Count { input, include_nulls, diff --git a/crates/polars-expr/src/reduce/first_last.rs b/crates/polars-expr/src/reduce/first_last.rs index 83db6a724fc9..ae781738547e 100644 --- a/crates/polars-expr/src/reduce/first_last.rs +++ b/crates/polars-expr/src/reduce/first_last.rs @@ -1,4 +1,5 @@ #![allow(unsafe_op_in_unsafe_fn)] +use std::fmt::Debug; use std::marker::PhantomData; use polars_core::frame::row::AnyValueBufferTrusted; @@ -14,6 +15,10 @@ pub fn new_last_reduction(dtype: DataType) -> Box { new_reduction_with_policy::(dtype) } +pub fn new_item_reduction(dtype: DataType) -> Box { + new_reduction_with_policy::(dtype) +} + fn new_reduction_with_policy(dtype: DataType) -> Box { use DataType::*; use VecGroupedReduction as VGR; @@ -42,6 +47,9 @@ fn new_reduction_with_policy(dtype: DataType) -> Box usize; fn should_replace(new: u64, old: u64) -> bool; + fn is_item_policy() -> bool { + false + } } struct First; @@ -68,9 +76,8 @@ impl Policy for Last { } } -#[allow(dead_code)] -struct Arbitrary; -impl Policy for Arbitrary { +struct Item; +impl Policy for Item { fn index(_len: usize) -> usize { 0 } @@ -78,10 +85,21 @@ impl Policy for Arbitrary { fn should_replace(_new: u64, old: u64) -> bool { old == 0 } + + fn is_item_policy() -> bool { + true + } } struct NumFirstLastReducer(PhantomData<(P, T)>); +#[derive(Clone, Debug, Default)] +struct Value { + value: Option, + seq: u64, + count: u64, +} + impl Clone for NumFirstLastReducer { fn clone(&self) -> Self { Self(PhantomData) @@ -94,10 +112,10 @@ where T: PolarsNumericType, { type Dtype = T; - type Value = (Option, u64); + type Value = Value; fn init(&self) -> Self::Value { - (None, 0) + Value::default() } fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { @@ -105,22 +123,28 @@ where } fn combine(&self, a: &mut Self::Value, b: &Self::Value) { - if P::should_replace(b.1, a.1) { - *a = *b; + if P::should_replace(b.seq, a.seq) { + a.value = b.value; + a.seq = b.seq; } + a.count += b.count; } fn reduce_one(&self, a: &mut Self::Value, b: Option, seq_id: u64) { - if P::should_replace(seq_id, a.1) { - *a = (b, seq_id); + if P::should_replace(seq_id, a.seq) { + a.value = b; + a.seq = seq_id; } + a.count += b.is_some() as u64; } fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64) { - if !ca.is_empty() && P::should_replace(seq_id, v.1) { + if !ca.is_empty() && P::should_replace(seq_id, v.seq) { let val = ca.get(P::index(ca.len())); - *v = (val, seq_id); + v.value = val; + v.seq = seq_id; } + v.count += ca.len() as u64; } fn finish( @@ -130,7 +154,13 @@ where dtype: &DataType, ) -> PolarsResult { assert!(m.is_none()); // This should only be used with VecGroupedReduction. - let ca: ChunkedArray = v.into_iter().map(|(x, _s)| x).collect_ca(PlSmallStr::EMPTY); + if P::is_item_policy() { + check_item_count_is_one(&v)?; + } + let ca: ChunkedArray = v + .into_iter() + .map(|red_val| red_val.value) + .collect_ca(PlSmallStr::EMPTY); let s = ca.into_series(); unsafe { s.from_physical_unchecked(dtype) } } @@ -159,10 +189,10 @@ where P: Policy, { type Dtype = BinaryType; - type Value = (Option>, u64); + type Value = Value>; fn init(&self) -> Self::Value { - (None, 0) + Value::default() } fn cast_series<'a>(&self, s: &'a Series) -> Cow<'a, Series> { @@ -170,24 +200,27 @@ where } fn combine(&self, a: &mut Self::Value, b: &Self::Value) { - if P::should_replace(b.1, a.1) { - a.0.clone_from(&b.0); - a.1 = b.1; + if P::should_replace(b.seq, a.seq) { + a.value.clone_from(&b.value); + a.seq = b.seq; } + a.count += b.count; } fn reduce_one(&self, a: &mut Self::Value, b: Option<&[u8]>, seq_id: u64) { - if P::should_replace(seq_id, a.1) { - replace_opt_bytes(&mut a.0, b); - a.1 = seq_id; + if P::should_replace(seq_id, a.seq) { + replace_opt_bytes(&mut a.value, b); + a.seq = seq_id; } + a.count += b.is_some() as u64; } fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64) { - if !ca.is_empty() && P::should_replace(seq_id, v.1) { - replace_opt_bytes(&mut v.0, ca.get(P::index(ca.len()))); - v.1 = seq_id; + if !ca.is_empty() && P::should_replace(seq_id, v.seq) { + replace_opt_bytes(&mut v.value, ca.get(P::index(ca.len()))); + v.seq = seq_id; } + v.count += ca.len() as u64; } fn finish( @@ -197,7 +230,13 @@ where dtype: &DataType, ) -> PolarsResult { assert!(m.is_none()); // This should only be used with VecGroupedReduction. - let ca: BinaryChunked = v.into_iter().map(|(x, _s)| x).collect_ca(PlSmallStr::EMPTY); + if P::is_item_policy() { + check_item_count_is_one(&v)?; + } + let ca: BinaryChunked = v + .into_iter() + .map(|Value { value, .. }| value) + .collect_ca(PlSmallStr::EMPTY); ca.into_series().cast(dtype) } } @@ -215,30 +254,34 @@ where P: Policy, { type Dtype = BooleanType; - type Value = (Option, u64); + type Value = Value; fn init(&self) -> Self::Value { - (None, 0) + Value::default() } fn combine(&self, a: &mut Self::Value, b: &Self::Value) { - if P::should_replace(b.1, a.1) { - *a = *b; + if P::should_replace(b.seq, a.seq) { + a.value = b.value; + a.seq = b.seq; } + a.count += b.count; } fn reduce_one(&self, a: &mut Self::Value, b: Option, seq_id: u64) { - if P::should_replace(seq_id, a.1) { - a.0 = b; - a.1 = seq_id; + if P::should_replace(seq_id, a.seq) { + a.value = b; + a.seq = seq_id; } + a.count += b.is_some() as u64; } fn reduce_ca(&self, v: &mut Self::Value, ca: &ChunkedArray, seq_id: u64) { - if !ca.is_empty() && P::should_replace(seq_id, v.1) { - v.0 = ca.get(P::index(ca.len())); - v.1 = seq_id; + if !ca.is_empty() && P::should_replace(seq_id, v.seq) { + v.value = ca.get(P::index(ca.len())); + v.seq = seq_id; } + v.count += ca.len() as u64; } fn finish( @@ -248,7 +291,13 @@ where _dtype: &DataType, ) -> PolarsResult { assert!(m.is_none()); // This should only be used with VecGroupedReduction. - let ca: BooleanChunked = v.into_iter().map(|(x, _s)| x).collect_ca(PlSmallStr::EMPTY); + if P::is_item_policy() { + check_item_count_is_one(&v)?; + } + let ca: BooleanChunked = v + .into_iter() + .map(|Value { value, .. }| value) + .collect_ca(PlSmallStr::EMPTY); Ok(ca.into_series()) } } @@ -257,8 +306,10 @@ pub struct GenericFirstLastGroupedReduction

{ in_dtype: DataType, values: Vec>, seqs: Vec, + counts: Vec, evicted_values: Vec>, evicted_seqs: Vec, + evicted_counts: Vec, policy: PhantomData P>, } @@ -268,8 +319,10 @@ impl

GenericFirstLastGroupedReduction

{ in_dtype, values: Vec::new(), seqs: Vec::new(), + counts: Vec::new(), evicted_values: Vec::new(), evicted_seqs: Vec::new(), + evicted_counts: Vec::new(), policy: PhantomData, } } @@ -283,11 +336,13 @@ impl GroupedReduction for GenericFirstLastGroupedReduction< fn reserve(&mut self, additional: usize) { self.values.reserve(additional); self.seqs.reserve(additional); + self.counts.reserve(additional); } fn resize(&mut self, num_groups: IdxSize) { self.values.resize(num_groups as usize, AnyValue::Null); self.seqs.resize(num_groups as usize, 0); + self.counts.resize(num_groups as usize, 0); } fn update_group( @@ -303,6 +358,7 @@ impl GroupedReduction for GenericFirstLastGroupedReduction< self.values[group_idx as usize] = values.get(P::index(values.len()))?.into_static(); self.seqs[group_idx as usize] = seq_id; } + self.counts[group_idx as usize] += values.len() as u64; } Ok(()) } @@ -320,15 +376,18 @@ impl GroupedReduction for GenericFirstLastGroupedReduction< for (i, g) in subset.iter().zip(group_idxs) { let grp_val = self.values.get_unchecked_mut(g.idx()); let grp_seq = self.seqs.get_unchecked_mut(g.idx()); + let grp_count = self.counts.get_unchecked_mut(g.idx()); if g.should_evict() { self.evicted_values .push(core::mem::replace(grp_val, AnyValue::Null)); self.evicted_seqs.push(core::mem::replace(grp_seq, 0)); + self.evicted_counts.push(core::mem::replace(grp_count, 0)); } if P::should_replace(seq_id, *grp_seq) { *grp_val = values.get_unchecked(*i as usize).into_static(); *grp_seq = seq_id; } + *self.counts.get_unchecked_mut(g.idx()) += 1; } Ok(()) } @@ -352,6 +411,7 @@ impl GroupedReduction for GenericFirstLastGroupedReduction< other.values.get_unchecked(si).clone(); *self.seqs.get_unchecked_mut(*g as usize) = *other.seqs.get_unchecked(si); } + *self.counts.get_unchecked_mut(*g as usize) += *other.counts.get_unchecked(si); } Ok(()) } @@ -361,14 +421,21 @@ impl GroupedReduction for GenericFirstLastGroupedReduction< in_dtype: self.in_dtype.clone(), values: core::mem::take(&mut self.evicted_values), seqs: core::mem::take(&mut self.evicted_seqs), + counts: core::mem::take(&mut self.evicted_counts), evicted_values: Vec::new(), evicted_seqs: Vec::new(), + evicted_counts: Vec::new(), policy: PhantomData, }) } fn finalize(&mut self) -> PolarsResult { self.seqs.clear(); + if P::is_item_policy() { + for count in self.counts.iter() { + polars_ensure!(*count == 1, item_agg_count_not_one = *count); + } + } unsafe { let mut buf = AnyValueBufferTrusted::new(&self.in_dtype, self.values.len()); for v in core::mem::take(&mut self.values) { @@ -382,3 +449,10 @@ impl GroupedReduction for GenericFirstLastGroupedReduction< self } } + +fn check_item_count_is_one(v: &[Value]) -> PolarsResult<()> { + if let Some(Value { count: n, .. }) = v.iter().find(|v| v.count != 1) { + polars_bail!(item_agg_count_not_one = *n); + } + Ok(()) +} diff --git a/crates/polars-plan/dsl-schema-hashes.json b/crates/polars-plan/dsl-schema-hashes.json index 831dc0d31e87..1a0043211a25 100644 --- a/crates/polars-plan/dsl-schema-hashes.json +++ b/crates/polars-plan/dsl-schema-hashes.json @@ -1,5 +1,5 @@ { - "AggExpr": "5398ac46a31d511fa6c645556c45b3ebeba6544df2629cabac079230822b1130", + "AggExpr": "2bdb1e6f50f333246ea8eb2d2139a2fe8f9b4b638160331c3f28fac186471544", "AnonymousColumnsUdf": "04e8b658fac4f09f7f9607c73be6fd3fe258064dd33468710f2c3e188c281a69", "AnyValue": "ef2b7f7588918138f192b3545a8474915a90d211b7c786e642427b5cd565d4ef", "ArrayDataTypeFunction": "f6606e9a91efce34563b32adb32473cd19d8c1e9b184b102be72268d14306136", diff --git a/crates/polars-plan/src/dsl/expr/mod.rs b/crates/polars-plan/src/dsl/expr/mod.rs index a504918a5d3e..90c8713429ab 100644 --- a/crates/polars-plan/src/dsl/expr/mod.rs +++ b/crates/polars-plan/src/dsl/expr/mod.rs @@ -37,6 +37,7 @@ pub enum AggExpr { NUnique(Arc), First(Arc), Last(Arc), + Item(Arc), Mean(Arc), Implode(Arc), Count { @@ -64,6 +65,7 @@ impl AsRef for AggExpr { NUnique(e) => e, First(e) => e, Last(e) => e, + Item(e) => e, Mean(e) => e, Implode(e) => e, Count { input, .. } => input, diff --git a/crates/polars-plan/src/dsl/format.rs b/crates/polars-plan/src/dsl/format.rs index c5295295ebdc..aecca2d16cc6 100644 --- a/crates/polars-plan/src/dsl/format.rs +++ b/crates/polars-plan/src/dsl/format.rs @@ -113,6 +113,7 @@ impl fmt::Debug for Expr { Mean(expr) => write!(f, "{expr:?}.mean()"), First(expr) => write!(f, "{expr:?}.first()"), Last(expr) => write!(f, "{expr:?}.last()"), + Item(expr) => write!(f, "{expr:?}.item()"), Implode(expr) => write!(f, "{expr:?}.list()"), NUnique(expr) => write!(f, "{expr:?}.n_unique()"), Sum(expr) => write!(f, "{expr:?}.sum()"), diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index f8f7f7ce4412..fd1c7610c73a 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -176,6 +176,11 @@ impl Expr { AggExpr::Last(Arc::new(self)).into() } + /// Get the single value in the group. If there are multiple values, an error is returned. + pub fn item(self) -> Self { + AggExpr::Item(Arc::new(self)).into() + } + /// GroupBy the group to a Series. pub fn implode(self) -> Self { AggExpr::Implode(Arc::new(self)).into() diff --git a/crates/polars-plan/src/plans/aexpr/equality.rs b/crates/polars-plan/src/plans/aexpr/equality.rs index 0f0038aa666a..9565e5d2dd17 100644 --- a/crates/polars-plan/src/plans/aexpr/equality.rs +++ b/crates/polars-plan/src/plans/aexpr/equality.rs @@ -111,6 +111,7 @@ impl IRAggExpr { A::NUnique(_) | A::First(_) | A::Last(_) | + A::Item(_) | A::Mean(_) | A::Implode(_) | A::Sum(_) | diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 9603bf3ff0e5..78a42036dcb8 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -48,6 +48,7 @@ pub enum IRAggExpr { NUnique(Node), First(Node), Last(Node), + Item(Node), Mean(Node), Implode(Node), Quantile { @@ -146,6 +147,7 @@ impl From for GroupByMethod { NUnique(_) => GroupByMethod::NUnique, First(_) => GroupByMethod::First, Last(_) => GroupByMethod::Last, + Item(_) => GroupByMethod::Item, Mean(_) => GroupByMethod::Mean, Implode(_) => GroupByMethod::Implode, Sum(_) => GroupByMethod::Sum, diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 80d2c793b595..21d36aa05f5b 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -139,7 +139,8 @@ impl AExpr { Max { input: expr, .. } | Min { input: expr, .. } | First(expr) - | Last(expr) => ctx.arena.get(*expr).to_field_impl(ctx), + | Last(expr) + | Item(expr) => ctx.arena.get(*expr).to_field_impl(ctx), Sum(expr) => { let mut field = ctx.arena.get(*expr).to_field_impl(ctx)?; let dt = match field.dtype() { @@ -319,6 +320,7 @@ impl AExpr { | Agg(Min { input: expr, .. }) | Agg(First(expr)) | Agg(Last(expr)) + | Agg(Item(expr)) | Agg(Sum(expr)) | Agg(Median(expr)) | Agg(Mean(expr)) diff --git a/crates/polars-plan/src/plans/aexpr/traverse.rs b/crates/polars-plan/src/plans/aexpr/traverse.rs index ffec8d5c861f..64a75980127e 100644 --- a/crates/polars-plan/src/plans/aexpr/traverse.rs +++ b/crates/polars-plan/src/plans/aexpr/traverse.rs @@ -245,6 +245,7 @@ impl IRAggExpr { pub fn get_input(&self) -> NodeInputs { use IRAggExpr::*; use NodeInputs::*; + match self { Min { input, .. } => Single(*input), Max { input, .. } => Single(*input), @@ -252,6 +253,7 @@ impl IRAggExpr { NUnique(input) => Single(*input), First(input) => Single(*input), Last(input) => Single(*input), + Item(input) => Single(*input), Mean(input) => Single(*input), Implode(input) => Single(*input), Quantile { expr, quantile, .. } => Many(vec![*expr, *quantile]), @@ -271,6 +273,7 @@ impl IRAggExpr { NUnique(input) => input, First(input) => input, Last(input) => input, + Item(input) => input, Mean(input) => input, Implode(input) => input, Quantile { expr, .. } => expr, diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs index a5c75cb4911f..9f72f8315138 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_expansion.rs @@ -467,6 +467,14 @@ fn expand_expression_rec( opt_flags, |e| Expr::Agg(AggExpr::Last(Arc::new(e))), )?, + AggExpr::Item(expr) => expand_single( + expr.as_ref(), + ignored_selector_columns, + schema, + out, + opt_flags, + |e| Expr::Agg(AggExpr::Item(Arc::new(e))), + )?, AggExpr::Mean(expr) => expand_single( expr.as_ref(), ignored_selector_columns, diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs index 4c952b33b1f6..9a5441693779 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir/expr_to_ir.rs @@ -253,6 +253,10 @@ pub(super) fn to_aexpr_impl( let (input, output_name) = to_aexpr_mat_lit_arc!(input)?; (IRAggExpr::Last(input), output_name) }, + AggExpr::Item(input) => { + let (input, output_name) = to_aexpr_mat_lit_arc!(input)?; + (IRAggExpr::Item(input), output_name) + }, AggExpr::Mean(input) => { let (input, output_name) = to_aexpr_mat_lit_arc!(input)?; (IRAggExpr::Mean(input), output_name) diff --git a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs index 1c551afd2009..dc4558e087af 100644 --- a/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs +++ b/crates/polars-plan/src/plans/conversion/ir_to_dsl.rs @@ -122,6 +122,10 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { let exp = node_to_expr(expr, expr_arena); AggExpr::Last(Arc::new(exp)).into() }, + IRAggExpr::Item(expr) => { + let exp = node_to_expr(expr, expr_arena); + AggExpr::Item(Arc::new(exp)).into() + }, IRAggExpr::Implode(expr) => { let exp = node_to_expr(expr, expr_arena); AggExpr::Implode(Arc::new(exp)).into() diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 257d2c2b6a5d..fd88daa4096b 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -452,6 +452,7 @@ impl Display for ExprIRDisplay<'_> { Mean(expr) => write!(f, "{}.mean()", self.with_root(expr)), First(expr) => write!(f, "{}.first()", self.with_root(expr)), Last(expr) => write!(f, "{}.last()", self.with_root(expr)), + Item(expr) => write!(f, "{}.item()", self.with_root(expr)), Implode(expr) => write!(f, "{}.implode()", self.with_root(expr)), NUnique(expr) => write!(f, "{}.n_unique()", self.with_root(expr)), Sum(expr) => write!(f, "{}.sum()", self.with_root(expr)), diff --git a/crates/polars-plan/src/plans/iterator.rs b/crates/polars-plan/src/plans/iterator.rs index bb8821c35038..31db4d1458f8 100644 --- a/crates/polars-plan/src/plans/iterator.rs +++ b/crates/polars-plan/src/plans/iterator.rs @@ -48,6 +48,7 @@ macro_rules! push_expr { NUnique(e) => $push($c, e), First(e) => $push($c, e), Last(e) => $push($c, e), + Item(e) => $push($c, e), Implode(e) => $push($c, e), Count { input, .. } => $push($c, input), Quantile { expr, .. } => $push($c, expr), diff --git a/crates/polars-plan/src/plans/optimizer/set_order/expr_pushdown.rs b/crates/polars-plan/src/plans/optimizer/set_order/expr_pushdown.rs index 9b0f22b22a41..4c8eda692182 100644 --- a/crates/polars-plan/src/plans/optimizer/set_order/expr_pushdown.rs +++ b/crates/polars-plan/src/plans/optimizer/set_order/expr_pushdown.rs @@ -208,7 +208,8 @@ impl<'a> ObservableOrdersResolver<'a> { | IRAggExpr::Sum(node) | IRAggExpr::Count { input: node, .. } | IRAggExpr::Std(node, _) - | IRAggExpr::Var(node, _) => { + | IRAggExpr::Var(node, _) + | IRAggExpr::Item(node) => { // Input order is deregarded, but must not observe order. _ = rec!(*node); O::None diff --git a/crates/polars-plan/src/plans/visitor/expr.rs b/crates/polars-plan/src/plans/visitor/expr.rs index ca96d72d56de..fbe3dd689f83 100644 --- a/crates/polars-plan/src/plans/visitor/expr.rs +++ b/crates/polars-plan/src/plans/visitor/expr.rs @@ -60,6 +60,7 @@ impl TreeWalker for Expr { NUnique(x) => NUnique(am(x, f)?), First(x) => First(am(x, f)?), Last(x) => Last(am(x, f)?), + Item(x) => Item(am(x, f)?), Mean(x) => Mean(am(x, f)?), Implode(x) => Implode(am(x, f)?), Count { input, include_nulls } => Count { input: am(input, f)?, include_nulls }, diff --git a/crates/polars-python/src/expr/general.rs b/crates/polars-python/src/expr/general.rs index 0dfc7b18fb4d..398f8e7231a8 100644 --- a/crates/polars-python/src/expr/general.rs +++ b/crates/polars-python/src/expr/general.rs @@ -152,6 +152,9 @@ impl PyExpr { fn last(&self) -> Self { self.inner.clone().last().into() } + fn item(&self) -> Self { + self.inner.clone().item().into() + } fn implode(&self) -> Self { self.inner.clone().implode().into() } diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 0c534be2fb31..b674cf09005b 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -684,6 +684,11 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult> { arguments: vec![n.0], options: py.None(), }, + IRAggExpr::Item(n) => Agg { + name: "item".into_py_any(py)?, + arguments: vec![n.0], + options: py.None(), + }, IRAggExpr::Mean(n) => Agg { name: "mean".into_py_any(py)?, arguments: vec![n.0], diff --git a/crates/polars-stream/src/physical_plan/lower_expr.rs b/crates/polars-stream/src/physical_plan/lower_expr.rs index c8699dabbcf9..c0070a901f30 100644 --- a/crates/polars-stream/src/physical_plan/lower_expr.rs +++ b/crates/polars-stream/src/physical_plan/lower_expr.rs @@ -1661,6 +1661,7 @@ fn lower_exprs_with_ctx( | IRAggExpr::Max { .. } | IRAggExpr::First(_) | IRAggExpr::Last(_) + | IRAggExpr::Item(_) | IRAggExpr::Sum(_) | IRAggExpr::Mean(_) | IRAggExpr::Var { .. } diff --git a/crates/polars-stream/src/physical_plan/lower_group_by.rs b/crates/polars-stream/src/physical_plan/lower_group_by.rs index 732e825aff3c..8a8c6b29b1f8 100644 --- a/crates/polars-stream/src/physical_plan/lower_group_by.rs +++ b/crates/polars-stream/src/physical_plan/lower_group_by.rs @@ -312,6 +312,7 @@ fn try_lower_elementwise_scalar_agg_expr( | IRAggExpr::Max { .. } | IRAggExpr::First(_) | IRAggExpr::Last(_) + | IRAggExpr::Item(_) | IRAggExpr::Mean(_) | IRAggExpr::Sum(_) | IRAggExpr::Var(..) diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index d0889614cf3c..bd5e2645a4e3 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -26,6 +26,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.gather_every Expr.list.get Expr.list.head + Expr.list.item Expr.list.join Expr.list.last Expr.list.len diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index 73b9aaee9b5f..0542f17a344d 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -33,6 +33,7 @@ Manipulation/selection Expr.inspect Expr.interpolate Expr.interpolate_by + Expr.item Expr.limit Expr.lower_bound Expr.pipe diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index d51fb1470bb2..b5dffa265603 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -26,6 +26,7 @@ The following methods are available under the `Series.list` attribute. Series.list.gather_every Series.list.get Series.list.head + Series.list.item Series.list.join Series.list.last Series.list.len diff --git a/py-polars/src/polars/_plr.pyi b/py-polars/src/polars/_plr.pyi index f11dc7608560..ccd2975702a0 100644 --- a/py-polars/src/polars/_plr.pyi +++ b/py-polars/src/polars/_plr.pyi @@ -1184,6 +1184,7 @@ class PyExpr: def unique_stable(self) -> PyExpr: ... def first(self) -> PyExpr: ... def last(self) -> PyExpr: ... + def item(self) -> PyExpr: ... def implode(self) -> PyExpr: ... def quantile(self, quantile: PyExpr, interpolation: Any) -> PyExpr: ... def cut( diff --git a/py-polars/src/polars/expr/expr.py b/py-polars/src/polars/expr/expr.py index dd87a1318ac0..a4ddc2b82f9a 100644 --- a/py-polars/src/polars/expr/expr.py +++ b/py-polars/src/polars/expr/expr.py @@ -3443,6 +3443,37 @@ def last(self) -> Expr: """ return wrap_expr(self._pyexpr.last()) + @unstable() + def item(self) -> Expr: + """ + Get the single value. + + This raises an error if there is not exactly one value. + + See Also + -------- + :meth:`Expr.get` : Get a single value by index. + + Examples + -------- + >>> df = pl.DataFrame({"a": [1]}) + >>> df.select(pl.col("a").item()) + shape: (1, 1) + ┌─────┐ + │ a │ + │ --- │ + │ i64 │ + ╞═════╡ + │ 1 │ + └─────┘ + >>> df = pl.DataFrame({"a": [1, 2, 3]}) + >>> df.select(pl.col("a").item()) + Traceback (most recent call last): + ... + polars.exceptions.ComputeError: aggregation 'item' expected a single value, got 3 values + """ # noqa: W505 + return wrap_expr(self._pyexpr.item()) + def over( self, partition_by: IntoExpr | Iterable[IntoExpr] | None = None, diff --git a/py-polars/src/polars/expr/list.py b/py-polars/src/polars/expr/list.py index c5f68cfac23a..5391d8b5ecbb 100644 --- a/py-polars/src/polars/expr/list.py +++ b/py-polars/src/polars/expr/list.py @@ -8,6 +8,7 @@ from polars import exceptions from polars import functions as F from polars._utils.parse import parse_into_expression +from polars._utils.unstable import unstable from polars._utils.various import issue_warning from polars._utils.wrap import wrap_expr @@ -683,6 +684,39 @@ def last(self) -> Expr: """ return self.get(-1, null_on_oob=True) + @unstable() + def item(self) -> Expr: + """ + Get the single value of the sublists. + + This errors if the sublist length is not exactly one. + + See Also + -------- + :meth:`Expr.list.get` : Get the value by index in the sublists. + + Examples + -------- + >>> df = pl.DataFrame({"a": [[3], [1], [2]]}) + >>> df.with_columns(item=pl.col("a").list.item()) + shape: (3, 2) + ┌───────────┬──────┐ + │ a ┆ item │ + │ --- ┆ --- │ + │ list[i64] ┆ i64 │ + ╞═══════════╪══════╡ + │ [3] ┆ 3 │ + │ [1] ┆ 1 │ + │ [2] ┆ 2 │ + └───────────┴──────┘ + >>> df = pl.DataFrame({"a": [[3, 2, 1], [1], [2]]}) + >>> df.select(pl.col("a").list.item()) + Traceback (most recent call last): + ... + polars.exceptions.ComputeError: aggregation 'item' expected a single value, got 3 values + """ # noqa: W505 + return self.agg(F.element().item()) + def contains(self, item: IntoExpr, *, nulls_equal: bool = True) -> Expr: """ Check if sublists contain the given item. diff --git a/py-polars/src/polars/series/list.py b/py-polars/src/polars/series/list.py index ea0755805488..425fffe9c370 100644 --- a/py-polars/src/polars/series/list.py +++ b/py-polars/src/polars/series/list.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable from polars import functions as F +from polars._utils.unstable import unstable from polars._utils.wrap import wrap_s from polars.series.utils import expr_dispatch @@ -570,6 +571,35 @@ def last(self) -> Series: ] """ + @unstable() + def item(self) -> Series: + """ + Get the single value of the sublists. + + This errors if the sublist length is not exactly one. + + See Also + -------- + :meth:`Series.list.get` : Get the value by index in the sublists. + + Examples + -------- + >>> s = pl.Series("a", [[1], [4], [6]]) + >>> s.list.item() + shape: (3,) + Series: 'a' [i64] + [ + 1 + 4 + 6 + ] + >>> df = pl.Series("a", [[3, 2, 1], [1], [2]]) + >>> df.list.item() + Traceback (most recent call last): + ... + polars.exceptions.ComputeError: aggregation 'item' expected a single value, got 3 values + """ # noqa: W505 + def contains(self, item: IntoExpr, *, nulls_equal: bool = True) -> Series: """ Check if sublists contain the given item. diff --git a/py-polars/tests/unit/operations/aggregation/test_aggregations.py b/py-polars/tests/unit/operations/aggregation/test_aggregations.py index 533faa0151ca..877c371fa998 100644 --- a/py-polars/tests/unit/operations/aggregation/test_aggregations.py +++ b/py-polars/tests/unit/operations/aggregation/test_aggregations.py @@ -5,10 +5,12 @@ import numpy as np import pytest +from hypothesis import given import polars as pl from polars.exceptions import InvalidOperationError from polars.testing import assert_frame_equal +from polars.testing.parametric import dataframes if TYPE_CHECKING: import numpy.typing as npt @@ -939,3 +941,77 @@ def test_invalid_agg_dtypes_should_raise( pl.exceptions.PolarsError, match=rf"`{op}` operation not supported for dtype" ): df.lazy().select(expr).collect(engine="streaming") + + +@given( + df=dataframes( + min_size=1, + max_size=1, + excluded_dtypes=[ + # TODO: polars/#24936 + pl.Struct, + ], + ) +) +def test_single(df: pl.DataFrame) -> None: + q = df.lazy().select(pl.all(ignore_nulls=False).item()) + assert_frame_equal(q.collect(), df) + assert_frame_equal(q.collect(engine="streaming"), df) + + +@given(df=dataframes(max_size=0)) +def test_single_empty(df: pl.DataFrame) -> None: + q = df.lazy().select(pl.all().item()) + match = "aggregation 'item' expected a single value, got none" + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect() + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect(engine="streaming") + + +@given(df=dataframes(min_size=2)) +def test_item_too_many(df: pl.DataFrame) -> None: + q = df.lazy().select(pl.all(ignore_nulls=False).item()) + match = f"aggregation 'item' expected a single value, got {df.height} values" + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect() + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect(engine="streaming") + + +@given( + df=dataframes( + min_size=1, + max_size=1, + allow_null=False, + excluded_dtypes=[ + # TODO: polars/#24936 + pl.Struct, + ], + ) +) +def test_item_on_groups(df: pl.DataFrame) -> None: + df = df.with_columns(pl.col("col0").alias("key")) + q = df.lazy().group_by("col0").agg(pl.all(ignore_nulls=False).item()) + assert_frame_equal(q.collect(), df) + assert_frame_equal(q.collect(engine="streaming"), df) + + +def test_item_on_groups_empty() -> None: + df = pl.DataFrame({"col0": [[]]}) + q = df.lazy().select(pl.all().list.item()) + match = "aggregation 'item' expected a single value, got none" + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect() + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect(engine="streaming") + + +def test_item_on_groups_too_many() -> None: + df = pl.DataFrame({"col0": [[1, 2, 3]]}) + q = df.lazy().select(pl.all().list.item()) + match = "aggregation 'item' expected a single value, got 3 values" + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect() + with pytest.raises(pl.exceptions.ComputeError, match=match): + q.collect(engine="streaming") diff --git a/py-polars/tests/unit/operations/namespaces/list/test_list.py b/py-polars/tests/unit/operations/namespaces/list/test_list.py index 752edb89a9ef..3c9e5fe2f7be 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_list.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_list.py @@ -49,6 +49,12 @@ def test_list_arr_get() -> None: expected_df = pl.Series("a", [None, None, None], dtype=pl.Int64).to_frame() assert_frame_equal(out_df, expected_df) + # item() + a = pl.Series("a", [[1], [4], [6]]) + expected = pl.Series("a", [1, 4, 6]) + out = a.list.item() + assert_series_equal(out, expected) + a = pl.Series("a", [[1, 2, 3], [4, 5], [6, 7, 8, 9]]) with pytest.raises(ComputeError, match="get index is out of bounds"): diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 60624d415158..9ac16b6d7a6f 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -1032,6 +1032,7 @@ def test_schema_on_agg() -> None: pl.col("b").sum().alias("sum"), pl.col("b").first().alias("first"), pl.col("b").last().alias("last"), + pl.col("b").item().alias("item"), ) expected_schema = { "a": pl.String, @@ -1040,6 +1041,7 @@ def test_schema_on_agg() -> None: "sum": pl.Int64, "first": pl.Int64, "last": pl.Int64, + "item": pl.Int64, } assert result.collect_schema() == expected_schema diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index 5a0ac685aa39..b89d68ec19d0 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -346,6 +346,7 @@ def test_cse_mixed_window_functions() -> None: pl.col("b").rank().alias("d_rank"), pl.col("b").first().over([pl.col("a")]).alias("b_first"), pl.col("b").last().over([pl.col("a")]).alias("b_last"), + pl.col("b").item().over([pl.col("a")]).alias("b_item"), pl.col("b").shift().alias("b_lag_1"), pl.col("b").shift().alias("b_lead_1"), pl.col("c").cum_sum().alias("c_cumsum"), @@ -363,6 +364,7 @@ def test_cse_mixed_window_functions() -> None: "d_rank": [1.0], "b_first": [1], "b_last": [1], + "b_item": [1], "b_lag_1": [None], "b_lead_1": [None], "c_cumsum": [1], diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 8cbdeb945be1..110cece2dfca 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -358,16 +358,16 @@ def test_lazy_agg_to_scalar_schema_19752(lhs: pl.Expr, expr_op: str) -> None: def test_lazy_agg_schema_after_elementwise_19984() -> None: lf = pl.LazyFrame({"a": 1, "b": 1}) - q = lf.group_by("a").agg(pl.col("b").first().fill_null(0)) + q = lf.group_by("a").agg(pl.col("b").item().fill_null(0)) assert q.collect_schema() == q.collect().collect_schema() - q = lf.group_by("a").agg(pl.col("b").first().fill_null(0).fill_null(0)) + q = lf.group_by("a").agg(pl.col("b").item().fill_null(0).fill_null(0)) assert q.collect_schema() == q.collect().collect_schema() - q = lf.group_by("a").agg(pl.col("b").first() + 1) + q = lf.group_by("a").agg(pl.col("b").item() + 1) assert q.collect_schema() == q.collect().collect_schema() - q = lf.group_by("a").agg(1 + pl.col("b").first()) + q = lf.group_by("a").agg(1 + pl.col("b").item()) assert q.collect_schema() == q.collect().collect_schema()