@@ -6,7 +6,7 @@ use vortex_array::compute::Operator;
66use vortex_array:: expr:: { Binary , Literal , Root , VTableExt , lit, root} ;
77use vortex_array:: transform:: { ArrayParentReduceRule , ArrayRuleContext } ;
88use vortex_array:: { ArrayRef , IntoArray } ;
9- use vortex_error:: VortexResult ;
9+ use vortex_error:: { VortexExpect , VortexResult } ;
1010use vortex_scalar:: { PrimitiveScalar , Scalar } ;
1111
1212use super :: compare_common:: { EncodedComparison , encode_for_comparison} ;
@@ -57,29 +57,10 @@ impl ArrayParentReduceRule<ALPVTable, ExprVTable> for ALPExprPushdownRule {
5757 return Ok ( None ) ;
5858 } ;
5959
60- // Get the comparison operator - only handle comparison operators
6160 let operator = binary_view. operator ( ) ;
62- if !matches ! (
63- operator,
64- vortex_array:: expr:: Operator :: Eq
65- | vortex_array:: expr:: Operator :: NotEq
66- | vortex_array:: expr:: Operator :: Lt
67- | vortex_array:: expr:: Operator :: Lte
68- | vortex_array:: expr:: Operator :: Gt
69- | vortex_array:: expr:: Operator :: Gte
70- ) {
71- return Ok ( None ) ;
72- }
7361
74- // Convert to compute operator
75- let compute_op = match operator {
76- vortex_array:: expr:: Operator :: Eq => Operator :: Eq ,
77- vortex_array:: expr:: Operator :: NotEq => Operator :: NotEq ,
78- vortex_array:: expr:: Operator :: Lt => Operator :: Lt ,
79- vortex_array:: expr:: Operator :: Lte => Operator :: Lte ,
80- vortex_array:: expr:: Operator :: Gt => Operator :: Gt ,
81- vortex_array:: expr:: Operator :: Gte => Operator :: Gte ,
82- _ => return Ok ( None ) ,
62+ let Some ( compute_op) = operator. maybe_cmp_operator ( ) else {
63+ return Ok ( None ) ;
8364 } ;
8465
8566 // Check if this is a comparison of root() with a literal
@@ -95,16 +76,13 @@ impl ArrayParentReduceRule<ALPVTable, ExprVTable> for ALPExprPushdownRule {
9576 return Ok ( None ) ;
9677 } ;
9778
98- // Get the literal scalar - literals evaluate to a constant array with one element
9979 let literal_value = literal_expr. as_ :: < Literal > ( ) . data ( ) . clone ( ) ;
10080
101- // Don't optimize nullable comparisons
10281 if literal_value. dtype ( ) . is_nullable ( ) {
10382 return Ok ( None ) ;
10483 }
10584
106- // Convert to primitive scalar
107- let Ok ( pscalar) = PrimitiveScalar :: try_from ( & literal_value) else {
85+ let Some ( pscalar) = literal_value. as_primitive_opt ( ) else {
10886 return Ok ( None ) ;
10987 } ;
11088
@@ -183,7 +161,7 @@ mod tests {
183161
184162 // Apply the optimization
185163 let session = ArraySession :: default ( ) ;
186- crate :: register_alp_rules ( & session) ;
164+ crate :: initialize ( & session) ;
187165 let expr_session = ExprSession :: default ( ) ;
188166 let optimizer = session. optimizer ( ExprOptimizer :: new ( & expr_session) ) ;
189167 let optimized = optimizer. optimize_array ( expr_array. into_array ( ) ) . unwrap ( ) ;
@@ -242,7 +220,7 @@ mod tests {
242220 let expr_array = ExprArray :: new_infer_dtype ( alp. clone ( ) . into_array ( ) , expr) . unwrap ( ) ;
243221
244222 let session = ArraySession :: default ( ) ;
245- crate :: register_alp_rules ( & session) ;
223+ crate :: initialize ( & session) ;
246224 let expr_session = ExprSession :: default ( ) ;
247225 let optimizer = session. optimizer ( ExprOptimizer :: new ( & expr_session) ) ;
248226 let optimized = optimizer. optimize_array ( expr_array. into_array ( ) ) . unwrap ( ) ;
@@ -297,7 +275,7 @@ mod tests {
297275 assert ! ( expr_array. child( ) . is:: <ALPVTable >( ) ) ;
298276
299277 let session = ArraySession :: default ( ) ;
300- crate :: register_alp_rules ( & session) ;
278+ crate :: initialize ( & session) ;
301279 let expr_session = ExprSession :: default ( ) ;
302280 let optimizer = session. optimizer ( ExprOptimizer :: new ( & expr_session) ) ;
303281 let optimized = optimizer. optimize_array ( expr_array. into_array ( ) ) . unwrap ( ) ;
@@ -343,7 +321,7 @@ mod tests {
343321 ExprArray :: new_infer_dtype ( alp. clone ( ) . into_array ( ) , expr. clone ( ) ) . unwrap ( ) ;
344322
345323 let session = ArraySession :: default ( ) ;
346- crate :: register_alp_rules ( & session) ;
324+ crate :: initialize ( & session) ;
347325 let expr_session = ExprSession :: default ( ) ;
348326 let optimizer = session. optimizer ( ExprOptimizer :: new ( & expr_session) ) ;
349327 let optimized = optimizer. optimize_array ( expr_array. into_array ( ) ) . unwrap ( ) ;
@@ -398,7 +376,7 @@ mod tests {
398376 let test_value = 0.06051f32 ;
399377
400378 let session = ArraySession :: default ( ) ;
401- crate :: register_alp_rules ( & session) ;
379+ crate :: initialize ( & session) ;
402380 let expr_session = ExprSession :: default ( ) ;
403381 let expr_optimizer = ExprOptimizer :: new ( & expr_session) ;
404382
0 commit comments