@@ -17,7 +17,8 @@ use crate::{
17
17
wrapper_pushdown_replacer, wrapper_replacer_context, AggregateFunctionExprDistinct ,
18
18
AggregateFunctionExprFun , AggregateUDFExprFun , AliasExprAlias , ColumnExprColumn ,
19
19
ListType , LiteralExprValue , LogicalPlanData , LogicalPlanLanguage ,
20
- WrappedSelectPushToCube , WrapperReplacerContextPushToCube ,
20
+ WrappedSelectPushToCube , WrapperReplacerContextAliasToCube ,
21
+ WrapperReplacerContextPushToCube ,
21
22
} ,
22
23
} ,
23
24
copy_flag,
@@ -26,7 +27,7 @@ use crate::{
26
27
} ;
27
28
use datafusion:: { logical_plan:: Column , scalar:: ScalarValue } ;
28
29
use egg:: { Subst , Var } ;
29
- use std:: ops:: IndexMut ;
30
+ use std:: { collections :: HashSet , ops:: IndexMut } ;
30
31
31
32
impl WrapperRules {
32
33
pub fn aggregate_rules ( & self , rules : & mut Vec < CubeRewrite > ) {
@@ -290,6 +291,7 @@ impl WrapperRules {
290
291
"?cube_members" ,
291
292
"?out_measure_expr" ,
292
293
"?out_measure_alias" ,
294
+ "?alias_to_cube" ,
293
295
) ,
294
296
)
295
297
} ,
@@ -1035,97 +1037,119 @@ impl WrapperRules {
1035
1037
cube_members_var : Var ,
1036
1038
out_expr_var : Var ,
1037
1039
out_alias_var : Var ,
1040
+ alias_to_cube_var : Var ,
1038
1041
meta : & MetaContext ,
1039
1042
disable_strict_agg_type_match : bool ,
1040
1043
) -> bool {
1041
1044
let Some ( alias) = original_expr_name ( egraph, subst[ original_expr_var] ) else {
1042
1045
return false ;
1043
1046
} ;
1044
1047
1045
- for fun in fun_name_var
1046
- . map ( |fun_var| {
1047
- var_iter ! ( egraph[ subst[ fun_var] ] , AggregateFunctionExprFun )
1048
- . map ( |fun| Some ( fun. clone ( ) ) )
1049
- . collect ( )
1050
- } )
1051
- . unwrap_or ( vec ! [ None ] )
1048
+ for alias_to_cube in var_iter ! (
1049
+ egraph[ subst[ alias_to_cube_var] ] ,
1050
+ WrapperReplacerContextAliasToCube
1051
+ )
1052
+ . cloned ( )
1053
+ . collect :: < Vec < _ > > ( )
1052
1054
{
1053
- for distinct in distinct_var
1054
- . map ( |distinct_var| {
1055
- var_iter ! ( egraph[ subst[ distinct_var] ] , AggregateFunctionExprDistinct )
1056
- . map ( |d| * d)
1055
+ // Do not push down COUNT(*) if there are joined cubes
1056
+ let is_count_rows = column_var. is_none ( ) ;
1057
+ if is_count_rows {
1058
+ let joined_cubes = alias_to_cube
1059
+ . iter ( )
1060
+ . map ( |( _, cube_name) | cube_name)
1061
+ . collect :: < HashSet < _ > > ( ) ;
1062
+ if joined_cubes. len ( ) > 1 {
1063
+ continue ;
1064
+ }
1065
+ }
1066
+
1067
+ for fun in fun_name_var
1068
+ . map ( |fun_var| {
1069
+ var_iter ! ( egraph[ subst[ fun_var] ] , AggregateFunctionExprFun )
1070
+ . map ( |fun| Some ( fun. clone ( ) ) )
1057
1071
. collect ( )
1058
1072
} )
1059
- . unwrap_or ( vec ! [ false ] )
1073
+ . unwrap_or ( vec ! [ None ] )
1060
1074
{
1061
- let call_agg_type = MemberRules :: get_agg_type ( fun. as_ref ( ) , distinct) ;
1075
+ for distinct in distinct_var
1076
+ . map ( |distinct_var| {
1077
+ var_iter ! ( egraph[ subst[ distinct_var] ] , AggregateFunctionExprDistinct )
1078
+ . map ( |d| * d)
1079
+ . collect ( )
1080
+ } )
1081
+ . unwrap_or ( vec ! [ false ] )
1082
+ {
1083
+ let call_agg_type = MemberRules :: get_agg_type ( fun. as_ref ( ) , distinct) ;
1062
1084
1063
- let column_iter = if let Some ( column_var) = column_var {
1064
- var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1065
- . cloned ( )
1066
- . collect ( )
1067
- } else {
1068
- vec ! [ Column :: from_name( MemberRules :: default_count_measure_name( ) ) ]
1069
- } ;
1085
+ let column_iter = if let Some ( column_var) = column_var {
1086
+ var_iter ! ( egraph[ subst[ column_var] ] , ColumnExprColumn )
1087
+ . cloned ( )
1088
+ . collect ( )
1089
+ } else {
1090
+ vec ! [ Column :: from_name( MemberRules :: default_count_measure_name( ) ) ]
1091
+ } ;
1070
1092
1071
- if let Some ( member_names_to_expr) = & mut egraph
1072
- . index_mut ( subst[ cube_members_var] )
1073
- . data
1074
- . member_name_to_expr
1075
- {
1076
- for column in column_iter {
1077
- if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1078
- LogicalPlanData :: do_find_member_by_alias (
1079
- member_names_to_expr,
1080
- & column. name ,
1081
- )
1082
- {
1083
- if let Some ( measure) = meta. find_measure_with_name ( member) {
1084
- let Some ( call_agg_type) = & call_agg_type else {
1085
- // call_agg_type is None, rewrite as is
1086
- Self :: insert_regular_measure (
1087
- egraph,
1088
- subst,
1089
- column,
1090
- alias,
1091
- out_expr_var,
1092
- out_alias_var,
1093
- ) ;
1093
+ if let Some ( member_names_to_expr) = & mut egraph
1094
+ . index_mut ( subst[ cube_members_var] )
1095
+ . data
1096
+ . member_name_to_expr
1097
+ {
1098
+ for column in column_iter {
1099
+ if let Some ( ( & ( Some ( ref member) , _, _) , _) ) =
1100
+ LogicalPlanData :: do_find_member_by_alias (
1101
+ member_names_to_expr,
1102
+ & column. name ,
1103
+ )
1104
+ {
1105
+ if let Some ( measure) = meta. find_measure_with_name ( member) {
1106
+ let Some ( call_agg_type) = & call_agg_type else {
1107
+ // call_agg_type is None, rewrite as is
1108
+ Self :: insert_regular_measure (
1109
+ egraph,
1110
+ subst,
1111
+ column,
1112
+ alias,
1113
+ out_expr_var,
1114
+ out_alias_var,
1115
+ ) ;
1094
1116
1095
- return true ;
1096
- } ;
1117
+ return true ;
1118
+ } ;
1097
1119
1098
- if measure
1099
- . is_same_agg_type ( call_agg_type, disable_strict_agg_type_match)
1100
- {
1101
- Self :: insert_regular_measure (
1102
- egraph,
1103
- subst,
1104
- column,
1105
- alias,
1106
- out_expr_var,
1107
- out_alias_var,
1108
- ) ;
1120
+ if measure. is_same_agg_type (
1121
+ call_agg_type,
1122
+ disable_strict_agg_type_match,
1123
+ ) {
1124
+ Self :: insert_regular_measure (
1125
+ egraph,
1126
+ subst,
1127
+ column,
1128
+ alias,
1129
+ out_expr_var,
1130
+ out_alias_var,
1131
+ ) ;
1109
1132
1110
- return true ;
1111
- }
1133
+ return true ;
1134
+ }
1112
1135
1113
- if measure. allow_replace_agg_type (
1114
- call_agg_type,
1115
- disable_strict_agg_type_match,
1116
- ) {
1117
- Self :: insert_patch_measure (
1118
- egraph,
1119
- subst,
1120
- column,
1121
- Some ( call_agg_type. clone ( ) ) ,
1122
- alias,
1123
- Some ( out_expr_var) ,
1124
- None ,
1125
- out_alias_var,
1126
- ) ;
1136
+ if measure. allow_replace_agg_type (
1137
+ call_agg_type,
1138
+ disable_strict_agg_type_match,
1139
+ ) {
1140
+ Self :: insert_patch_measure (
1141
+ egraph,
1142
+ subst,
1143
+ column,
1144
+ Some ( call_agg_type. clone ( ) ) ,
1145
+ alias,
1146
+ Some ( out_expr_var) ,
1147
+ None ,
1148
+ out_alias_var,
1149
+ ) ;
1127
1150
1128
- return true ;
1151
+ return true ;
1152
+ }
1129
1153
}
1130
1154
}
1131
1155
}
@@ -1148,6 +1172,7 @@ impl WrapperRules {
1148
1172
cube_members_var : & ' static str ,
1149
1173
out_expr_var : & ' static str ,
1150
1174
out_alias_var : & ' static str ,
1175
+ alias_to_cube_var : & ' static str ,
1151
1176
) -> impl Fn ( & mut CubeEGraph , & mut Subst ) -> bool {
1152
1177
let original_expr_var = var ! ( original_expr_var) ;
1153
1178
let column_var = column_var. map ( |v| var ! ( v) ) ;
@@ -1157,6 +1182,7 @@ impl WrapperRules {
1157
1182
let cube_members_var = var ! ( cube_members_var) ;
1158
1183
let out_expr_var = var ! ( out_expr_var) ;
1159
1184
let out_alias_var = var ! ( out_alias_var) ;
1185
+ let alias_to_cube_var = var ! ( alias_to_cube_var) ;
1160
1186
let meta = self . meta_context . clone ( ) ;
1161
1187
let disable_strict_agg_type_match = self . config_obj . disable_strict_agg_type_match ( ) ;
1162
1188
move |egraph, subst| {
@@ -1170,6 +1196,7 @@ impl WrapperRules {
1170
1196
cube_members_var,
1171
1197
out_expr_var,
1172
1198
out_alias_var,
1199
+ alias_to_cube_var,
1173
1200
& meta,
1174
1201
disable_strict_agg_type_match,
1175
1202
)
0 commit comments