Skip to content

Commit e073a72

Browse files
authored
feat(cubesql): Avoid COUNT(*) pushdown to joined cubes (#9905)
1 parent 585e633 commit e073a72

File tree

5 files changed

+172
-80
lines changed

5 files changed

+172
-80
lines changed

packages/cubejs-schema-compiler/src/adapter/BaseQuery.js

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,10 +2134,6 @@ export class BaseQuery {
21342134
if (m.expressionName && !collectedMeasures.length && !m.isMemberExpression) {
21352135
throw new UserError(`Subquery measure ${m.expressionName} should reference at least one member`);
21362136
}
2137-
if (!collectedMeasures.length && m.isMemberExpression && m.query.allCubeNames.length > 1 && m.measureSql() === 'COUNT(*)') {
2138-
const cubeName = m.expressionCubeName ? `\`${m.expressionCubeName}\` ` : '';
2139-
throw new UserError(`The query contains \`COUNT(*)\` expression but cube/view ${cubeName}is missing \`count\` measure`);
2140-
}
21412137

21422138
if (collectedMeasures.length === 0 && m.isMemberExpression) {
21432139
// `m` is member expression measure, but does not reference any other measure

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3824,6 +3824,21 @@ impl<'ctx, 'mem> CollectMembersVisitor<'ctx, 'mem> {
38243824

38253825
Ok(())
38263826
}
3827+
3828+
fn handle_count_rows(&mut self) -> Result<()> {
3829+
// COUNT(*) references all members in the ungrouped scan node
3830+
for member in &self.push_to_cube_context.ungrouped_scan_node.member_fields {
3831+
match member {
3832+
MemberField::Member(member) => {
3833+
self.used_members.insert(member.member.clone());
3834+
}
3835+
MemberField::Literal(_) => {
3836+
// Do nothing
3837+
}
3838+
}
3839+
}
3840+
Ok(())
3841+
}
38273842
}
38283843

38293844
impl<'ctx, 'mem> ExpressionVisitor for CollectMembersVisitor<'ctx, 'mem> {
@@ -3832,6 +3847,13 @@ impl<'ctx, 'mem> ExpressionVisitor for CollectMembersVisitor<'ctx, 'mem> {
38323847
Expr::Column(ref c) => {
38333848
self.handle_column(c)?;
38343849
}
3850+
Expr::AggregateFunction {
3851+
fun: AggregateFunction::Count,
3852+
args,
3853+
..
3854+
} if args.len() == 1 && matches!(args[0], Expr::Literal(_)) => {
3855+
self.handle_count_rows()?;
3856+
}
38353857
_ => {}
38363858
}
38373859

rust/cubesql/cubesql/src/compile/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17468,4 +17468,40 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),
1746817468
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
1746917469
assert!(sql.contains("DATE_DIFF('day', "));
1747017470
}
17471+
17472+
#[tokio::test]
17473+
async fn test_count_over_joined_cubes() {
17474+
if !Rewriter::sql_push_down_enabled() {
17475+
return;
17476+
}
17477+
init_testing_logger();
17478+
17479+
let query_plan = convert_select_to_query_plan(
17480+
r#"
17481+
SELECT COUNT(*)
17482+
FROM (
17483+
SELECT
17484+
t1.id AS id,
17485+
t2.read AS read
17486+
FROM KibanaSampleDataEcommerce t1
17487+
LEFT JOIN Logs t2 ON t1.__cubeJoinField = t2.__cubeJoinField
17488+
) t
17489+
"#
17490+
.to_string(),
17491+
DatabaseProtocol::PostgreSQL,
17492+
)
17493+
.await;
17494+
17495+
let logical_plan = query_plan.as_logical_plan();
17496+
let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql;
17497+
assert!(sql.contains("COUNT(*)"));
17498+
assert!(sql.contains("KibanaSampleDataEcommerce"));
17499+
assert!(sql.contains("Logs"));
17500+
17501+
let physical_plan = query_plan.as_physical_plan().await.unwrap();
17502+
println!(
17503+
"Physical plan: {}",
17504+
displayable(physical_plan.as_ref()).indent()
17505+
);
17506+
}
1747117507
}

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2062,7 +2062,7 @@ impl MemberRules {
20622062
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
20632063
let member_pushdown_replacer_alias_to_cube_var =
20642064
var!(member_pushdown_replacer_alias_to_cube_var);
2065-
let column_var = match column_to_search {
2065+
let column_var = match &column_to_search {
20662066
ColumnToSearch::Var(column_var) => Some(var!(column_var)),
20672067
ColumnToSearch::DefaultCount => None,
20682068
};
@@ -2088,6 +2088,17 @@ impl MemberRules {
20882088
});
20892089

20902090
for alias_to_cube in alias_to_cubes {
2091+
// Do not push down COUNT(*) if there are joined cubes
2092+
if matches!(column_to_search, ColumnToSearch::DefaultCount) {
2093+
let joined_cubes = alias_to_cube
2094+
.iter()
2095+
.map(|(_, cube_name)| cube_name)
2096+
.collect::<HashSet<_>>();
2097+
if joined_cubes.len() > 1 {
2098+
continue;
2099+
}
2100+
}
2101+
20912102
let column_iter = match column_var {
20922103
Some(column_var) => var_iter!(egraph[subst[column_var]], ColumnExprColumn)
20932104
.cloned()

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs

Lines changed: 102 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ use crate::{
1717
wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct,
1818
AggregateFunctionExprFun, AggregateUDFExprFun, AliasExprAlias, ColumnExprColumn,
1919
ListType, LiteralExprValue, LogicalPlanData, LogicalPlanLanguage,
20-
WrappedSelectPushToCube, WrapperReplacerContextPushToCube,
20+
WrappedSelectPushToCube, WrapperReplacerContextAliasToCube,
21+
WrapperReplacerContextPushToCube,
2122
},
2223
},
2324
copy_flag,
@@ -26,7 +27,7 @@ use crate::{
2627
};
2728
use datafusion::{logical_plan::Column, scalar::ScalarValue};
2829
use egg::{Subst, Var};
29-
use std::ops::IndexMut;
30+
use std::{collections::HashSet, ops::IndexMut};
3031

3132
impl WrapperRules {
3233
pub fn aggregate_rules(&self, rules: &mut Vec<CubeRewrite>) {
@@ -290,6 +291,7 @@ impl WrapperRules {
290291
"?cube_members",
291292
"?out_measure_expr",
292293
"?out_measure_alias",
294+
"?alias_to_cube",
293295
),
294296
)
295297
},
@@ -1035,97 +1037,119 @@ impl WrapperRules {
10351037
cube_members_var: Var,
10361038
out_expr_var: Var,
10371039
out_alias_var: Var,
1040+
alias_to_cube_var: Var,
10381041
meta: &MetaContext,
10391042
disable_strict_agg_type_match: bool,
10401043
) -> bool {
10411044
let Some(alias) = original_expr_name(egraph, subst[original_expr_var]) else {
10421045
return false;
10431046
};
10441047

1045-
for fun in fun_name_var
1046-
.map(|fun_var| {
1047-
var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun)
1048-
.map(|fun| Some(fun.clone()))
1049-
.collect()
1050-
})
1051-
.unwrap_or(vec![None])
1048+
for alias_to_cube in var_iter!(
1049+
egraph[subst[alias_to_cube_var]],
1050+
WrapperReplacerContextAliasToCube
1051+
)
1052+
.cloned()
1053+
.collect::<Vec<_>>()
10521054
{
1053-
for distinct in distinct_var
1054-
.map(|distinct_var| {
1055-
var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct)
1056-
.map(|d| *d)
1055+
// Do not push down COUNT(*) if there are joined cubes
1056+
let is_count_rows = column_var.is_none();
1057+
if is_count_rows {
1058+
let joined_cubes = alias_to_cube
1059+
.iter()
1060+
.map(|(_, cube_name)| cube_name)
1061+
.collect::<HashSet<_>>();
1062+
if joined_cubes.len() > 1 {
1063+
continue;
1064+
}
1065+
}
1066+
1067+
for fun in fun_name_var
1068+
.map(|fun_var| {
1069+
var_iter!(egraph[subst[fun_var]], AggregateFunctionExprFun)
1070+
.map(|fun| Some(fun.clone()))
10571071
.collect()
10581072
})
1059-
.unwrap_or(vec![false])
1073+
.unwrap_or(vec![None])
10601074
{
1061-
let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct);
1075+
for distinct in distinct_var
1076+
.map(|distinct_var| {
1077+
var_iter!(egraph[subst[distinct_var]], AggregateFunctionExprDistinct)
1078+
.map(|d| *d)
1079+
.collect()
1080+
})
1081+
.unwrap_or(vec![false])
1082+
{
1083+
let call_agg_type = MemberRules::get_agg_type(fun.as_ref(), distinct);
10621084

1063-
let column_iter = if let Some(column_var) = column_var {
1064-
var_iter!(egraph[subst[column_var]], ColumnExprColumn)
1065-
.cloned()
1066-
.collect()
1067-
} else {
1068-
vec![Column::from_name(MemberRules::default_count_measure_name())]
1069-
};
1085+
let column_iter = if let Some(column_var) = column_var {
1086+
var_iter!(egraph[subst[column_var]], ColumnExprColumn)
1087+
.cloned()
1088+
.collect()
1089+
} else {
1090+
vec![Column::from_name(MemberRules::default_count_measure_name())]
1091+
};
10701092

1071-
if let Some(member_names_to_expr) = &mut egraph
1072-
.index_mut(subst[cube_members_var])
1073-
.data
1074-
.member_name_to_expr
1075-
{
1076-
for column in column_iter {
1077-
if let Some((&(Some(ref member), _, _), _)) =
1078-
LogicalPlanData::do_find_member_by_alias(
1079-
member_names_to_expr,
1080-
&column.name,
1081-
)
1082-
{
1083-
if let Some(measure) = meta.find_measure_with_name(member) {
1084-
let Some(call_agg_type) = &call_agg_type else {
1085-
// call_agg_type is None, rewrite as is
1086-
Self::insert_regular_measure(
1087-
egraph,
1088-
subst,
1089-
column,
1090-
alias,
1091-
out_expr_var,
1092-
out_alias_var,
1093-
);
1093+
if let Some(member_names_to_expr) = &mut egraph
1094+
.index_mut(subst[cube_members_var])
1095+
.data
1096+
.member_name_to_expr
1097+
{
1098+
for column in column_iter {
1099+
if let Some((&(Some(ref member), _, _), _)) =
1100+
LogicalPlanData::do_find_member_by_alias(
1101+
member_names_to_expr,
1102+
&column.name,
1103+
)
1104+
{
1105+
if let Some(measure) = meta.find_measure_with_name(member) {
1106+
let Some(call_agg_type) = &call_agg_type else {
1107+
// call_agg_type is None, rewrite as is
1108+
Self::insert_regular_measure(
1109+
egraph,
1110+
subst,
1111+
column,
1112+
alias,
1113+
out_expr_var,
1114+
out_alias_var,
1115+
);
10941116

1095-
return true;
1096-
};
1117+
return true;
1118+
};
10971119

1098-
if measure
1099-
.is_same_agg_type(call_agg_type, disable_strict_agg_type_match)
1100-
{
1101-
Self::insert_regular_measure(
1102-
egraph,
1103-
subst,
1104-
column,
1105-
alias,
1106-
out_expr_var,
1107-
out_alias_var,
1108-
);
1120+
if measure.is_same_agg_type(
1121+
call_agg_type,
1122+
disable_strict_agg_type_match,
1123+
) {
1124+
Self::insert_regular_measure(
1125+
egraph,
1126+
subst,
1127+
column,
1128+
alias,
1129+
out_expr_var,
1130+
out_alias_var,
1131+
);
11091132

1110-
return true;
1111-
}
1133+
return true;
1134+
}
11121135

1113-
if measure.allow_replace_agg_type(
1114-
call_agg_type,
1115-
disable_strict_agg_type_match,
1116-
) {
1117-
Self::insert_patch_measure(
1118-
egraph,
1119-
subst,
1120-
column,
1121-
Some(call_agg_type.clone()),
1122-
alias,
1123-
Some(out_expr_var),
1124-
None,
1125-
out_alias_var,
1126-
);
1136+
if measure.allow_replace_agg_type(
1137+
call_agg_type,
1138+
disable_strict_agg_type_match,
1139+
) {
1140+
Self::insert_patch_measure(
1141+
egraph,
1142+
subst,
1143+
column,
1144+
Some(call_agg_type.clone()),
1145+
alias,
1146+
Some(out_expr_var),
1147+
None,
1148+
out_alias_var,
1149+
);
11271150

1128-
return true;
1151+
return true;
1152+
}
11291153
}
11301154
}
11311155
}
@@ -1148,6 +1172,7 @@ impl WrapperRules {
11481172
cube_members_var: &'static str,
11491173
out_expr_var: &'static str,
11501174
out_alias_var: &'static str,
1175+
alias_to_cube_var: &'static str,
11511176
) -> impl Fn(&mut CubeEGraph, &mut Subst) -> bool {
11521177
let original_expr_var = var!(original_expr_var);
11531178
let column_var = column_var.map(|v| var!(v));
@@ -1157,6 +1182,7 @@ impl WrapperRules {
11571182
let cube_members_var = var!(cube_members_var);
11581183
let out_expr_var = var!(out_expr_var);
11591184
let out_alias_var = var!(out_alias_var);
1185+
let alias_to_cube_var = var!(alias_to_cube_var);
11601186
let meta = self.meta_context.clone();
11611187
let disable_strict_agg_type_match = self.config_obj.disable_strict_agg_type_match();
11621188
move |egraph, subst| {
@@ -1170,6 +1196,7 @@ impl WrapperRules {
11701196
cube_members_var,
11711197
out_expr_var,
11721198
out_alias_var,
1199+
alias_to_cube_var,
11731200
&meta,
11741201
disable_strict_agg_type_match,
11751202
)

0 commit comments

Comments
 (0)