Skip to content

Commit 8d8163b

Browse files
committed
fix: Optimize memory for groups iter in NotAggregated state
1 parent f544238 commit 8d8163b

File tree

2 files changed

+130
-14
lines changed

2 files changed

+130
-14
lines changed

crates/polars-expr/src/expressions/apply.rs

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,12 @@ impl ApplyExpr {
105105
&self,
106106
mut ac: AggregationContext<'a>,
107107
) -> PolarsResult<AggregationContext<'a>> {
108+
dbg!("start apply_single_group_aware"); //kdn
109+
dbg!(&self.expr);
110+
// dbg!(&ac);
111+
108112
let s = ac.get_values();
113+
let name = s.name().clone();
109114

110115
#[allow(clippy::nonminimal_bool)]
111116
{
@@ -116,20 +121,6 @@ impl ApplyExpr {
116121
);
117122
}
118123

119-
let name = s.name().clone();
120-
let agg = ac.aggregated();
121-
// Collection of empty list leads to a null dtype. See: #3687.
122-
if agg.is_empty() {
123-
// Create input for the function to determine the output dtype, see #3946.
124-
let agg = agg.list().unwrap();
125-
let input_dtype = agg.inner_dtype();
126-
let input = Column::full_null(name.clone(), 0, input_dtype);
127-
128-
let output = self.eval_and_flatten(&mut [input])?;
129-
let ca = ListChunked::full(name, output.as_materialized_series(), 0);
130-
return self.finish_apply_groups(ac, ca);
131-
}
132-
133124
let f = |opt_s: Option<Series>| match opt_s {
134125
None => Ok(None),
135126
Some(mut s) => {
@@ -144,6 +135,36 @@ impl ApplyExpr {
144135
},
145136
};
146137

138+
// In case of overlapping (rolling) groups, we build groups in a lazy manner to avoid
139+
// memory explosion.
140+
// kdn TODO: TBD - do we want to follow this path for *all* Slice
141+
if matches!(ac.agg_state(), AggState::NotAggregated(_))
142+
&& let GroupsType::Slice { rolling: true, .. } = ac.groups.as_ref().as_ref()
143+
{
144+
let ca: ChunkedArray<_> = ac
145+
.iter_groups_lazy(false)
146+
.map(|opt| opt.map(|s| s.as_ref().clone()))
147+
.map(f)
148+
.collect::<PolarsResult<_>>()?;
149+
150+
return self.finish_apply_groups(ac, ca.with_name(name));
151+
}
152+
153+
// At this point, calling aggregated will not lead to memory explosion.
154+
let agg = ac.aggregated();
155+
156+
// Collection of empty list leads to a null dtype. See: #3687.
157+
if agg.is_empty() {
158+
// Create input for the function to determine the output dtype, see #3946.
159+
let agg = agg.list().unwrap();
160+
let input_dtype = agg.inner_dtype();
161+
let input = Column::full_null(name.clone(), 0, input_dtype);
162+
163+
let output = self.eval_and_flatten(&mut [input])?;
164+
let ca = ListChunked::full(name, output.as_materialized_series(), 0);
165+
return self.finish_apply_groups(ac, ca);
166+
}
167+
147168
let ca: ListChunked = if self.allow_threading {
148169
let dtype = if self.output_field.dtype.is_known() && !self.output_field.dtype.is_null()
149170
{

crates/polars-expr/src/expressions/group_iter.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,41 @@ impl AggregationContext<'_> {
7373
}
7474
}
7575

impl AggregationContext<'_> {
    /// Iterate over groups without greedy aggregation into an AggList.
    ///
    /// Unlike `iter_groups`, this does not materialize an aggregated list
    /// column up front; each group is yielded as an `AmortSeries` view over
    /// the flat values. This avoids the memory blow-up when groups overlap
    /// (e.g. rolling windows), where greedy aggregation would duplicate data.
    ///
    /// Falls back to the eager `iter_groups` for any state other than
    /// `NotAggregated`.
    pub(super) fn iter_groups_lazy(
        &mut self,
        keep_names: bool,
    ) -> Box<dyn Iterator<Item = Option<AmortSeries>> + '_> {
        match self.agg_state() {
            AggState::NotAggregated(_) => {
                let groups = self.groups();
                let len = groups.len();
                // Rechunk so the values live in a single chunk: the iterator
                // below slices `array_ref(0)` only, so multiple chunks would
                // silently drop data. This makes the rechunk required.
                let c = self.get_values().rechunk();
                let name = if keep_names {
                    c.name().clone()
                } else {
                    PlSmallStr::EMPTY
                };
                // Re-borrow groups here: the earlier `groups()` borrow had to
                // end before `get_values()` could borrow `self` again.
                let iter = self.groups().iter();

                // Safety:
                // The array handed to `NotAggLazyIter::new` is taken from `c`
                // and paired with `c.dtype()`, so the physical chunk matches
                // the logical dtype it is rebuilt with — TODO confirm this is
                // the full invariant `new` requires.
                unsafe {
                    Box::new(NotAggLazyIter::new(
                        c.as_materialized_series().array_ref(0).clone(),
                        iter,
                        len,
                        c.dtype(),
                        name,
                    ))
                }
            },
            _ => self.iter_groups(keep_names),
        }
    }
}
110+
76111
struct LitIter {
77112
len: usize,
78113
offset: usize,
@@ -186,3 +221,63 @@ impl Iterator for FlatIter {
186221
(self.len - self.offset, Some(self.len - self.offset))
187222
}
188223
}
224+
/// Lazy per-group iterator over a `NotAggregated` column.
///
/// Yields one amortized `Series` view per group by slicing the single backing
/// chunk, instead of materializing all groups into an aggregated list.
struct NotAggLazyIter<'a, I: Iterator<Item = GroupsIndicator<'a>>> {
    // Single backing chunk of the flat values; each group is a slice of it.
    array: ArrayRef,
    // Iterator over the group indicators (offsets/indices).
    iter: I,
    // Number of groups already consumed; used by `size_hint`.
    groups_idx: usize,
    // Total number of groups.
    len: usize,
    // Kept alive because the `item` AmortSeries references this series.
    #[allow(dead_code)]
    series_container: Rc<Series>,
    // Reusable amortized series whose backing array is swapped per group.
    item: AmortSeries,
}
235+
impl<'a, I: Iterator<Item = GroupsIndicator<'a>>> NotAggLazyIter<'a, I> {
    /// Build a lazy group iterator over `array`.
    ///
    /// `iter` yields the group indicators, `len` is the total group count,
    /// and `logical` is the dtype used to reinterpret the physical chunk.
    ///
    /// # Safety
    /// `array`'s physical representation must be valid for the `logical`
    /// dtype: the chunk is wrapped via `Series::from_chunks_and_dtype_unchecked`,
    /// which performs no validation. Callers should obtain `array` and
    /// `logical` from the same column (as `iter_groups_lazy` does) — TODO
    /// confirm no additional invariants are required.
    unsafe fn new(
        array: ArrayRef,
        iter: I,
        len: usize,
        logical: &DataType,
        name: PlSmallStr,
    ) -> Self {
        let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked(
            name,
            vec![array.clone()],
            logical,
        ));
        Self {
            array,
            iter,
            groups_idx: 0,
            len,
            series_container: series_container.clone(),
            item: AmortSeries::new(series_container),
        }
    }
}
261+
262+
impl<'a, I: Iterator<Item = GroupsIndicator<'a>>> Iterator for NotAggLazyIter<'a, I> {
263+
type Item = Option<AmortSeries>;
264+
265+
fn next(&mut self) -> Option<Self::Item> {
266+
if let Some(g) = self.iter.next() {
267+
match g {
268+
GroupsIndicator::Idx(_) => todo!(), //kdn TODO
269+
GroupsIndicator::Slice(s) => {
270+
let mut arr =
271+
unsafe { self.array.sliced_unchecked(s[0] as usize, s[1] as usize) };
272+
unsafe { self.item.swap(&mut arr) };
273+
Some(Some(self.item.clone()))
274+
},
275+
}
276+
} else {
277+
None
278+
}
279+
}
280+
fn size_hint(&self) -> (usize, Option<usize>) {
281+
(self.len - self.groups_idx, Some(self.len - self.groups_idx))
282+
}
283+
}

0 commit comments

Comments
 (0)