Skip to content

Commit f6f04e7

Browse files
committed
wip
Signed-off-by: Joe Isaacs <[email protected]>
1 parent e3bea37 commit f6f04e7

File tree

3 files changed

+15
-35
lines changed

3 files changed

+15
-35
lines changed

encodings/alp/src/alp/compute/expr_pushdown.rs

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use vortex_array::compute::Operator;
66
use vortex_array::expr::{Binary, Literal, Root, VTableExt, lit, root};
77
use vortex_array::transform::{ArrayParentReduceRule, ArrayRuleContext};
88
use vortex_array::{ArrayRef, IntoArray};
9-
use vortex_error::VortexResult;
9+
use vortex_error::{VortexExpect, VortexResult};
1010
use vortex_scalar::{PrimitiveScalar, Scalar};
1111

1212
use super::compare_common::{EncodedComparison, encode_for_comparison};
@@ -57,29 +57,10 @@ impl ArrayParentReduceRule<ALPVTable, ExprVTable> for ALPExprPushdownRule {
5757
return Ok(None);
5858
};
5959

60-
// Get the comparison operator - only handle comparison operators
6160
let operator = binary_view.operator();
62-
if !matches!(
63-
operator,
64-
vortex_array::expr::Operator::Eq
65-
| vortex_array::expr::Operator::NotEq
66-
| vortex_array::expr::Operator::Lt
67-
| vortex_array::expr::Operator::Lte
68-
| vortex_array::expr::Operator::Gt
69-
| vortex_array::expr::Operator::Gte
70-
) {
71-
return Ok(None);
72-
}
7361

74-
// Convert to compute operator
75-
let compute_op = match operator {
76-
vortex_array::expr::Operator::Eq => Operator::Eq,
77-
vortex_array::expr::Operator::NotEq => Operator::NotEq,
78-
vortex_array::expr::Operator::Lt => Operator::Lt,
79-
vortex_array::expr::Operator::Lte => Operator::Lte,
80-
vortex_array::expr::Operator::Gt => Operator::Gt,
81-
vortex_array::expr::Operator::Gte => Operator::Gte,
82-
_ => return Ok(None),
62+
let Some(compute_op) = operator.maybe_cmp_operator() else {
63+
return Ok(None);
8364
};
8465

8566
// Check if this is a comparison of root() with a literal
@@ -95,16 +76,13 @@ impl ArrayParentReduceRule<ALPVTable, ExprVTable> for ALPExprPushdownRule {
9576
return Ok(None);
9677
};
9778

98-
// Get the literal scalar - literals evaluate to a constant array with one element
9979
let literal_value = literal_expr.as_::<Literal>().data().clone();
10080

101-
// Don't optimize nullable comparisons
10281
if literal_value.dtype().is_nullable() {
10382
return Ok(None);
10483
}
10584

106-
// Convert to primitive scalar
107-
let Ok(pscalar) = PrimitiveScalar::try_from(&literal_value) else {
85+
let Some(pscalar) = literal_value.as_primitive_opt() else {
10886
return Ok(None);
10987
};
11088

@@ -183,7 +161,7 @@ mod tests {
183161

184162
// Apply the optimization
185163
let session = ArraySession::default();
186-
crate::register_alp_rules(&session);
164+
crate::initialize(&session);
187165
let expr_session = ExprSession::default();
188166
let optimizer = session.optimizer(ExprOptimizer::new(&expr_session));
189167
let optimized = optimizer.optimize_array(expr_array.into_array()).unwrap();
@@ -242,7 +220,7 @@ mod tests {
242220
let expr_array = ExprArray::new_infer_dtype(alp.clone().into_array(), expr).unwrap();
243221

244222
let session = ArraySession::default();
245-
crate::register_alp_rules(&session);
223+
crate::initialize(&session);
246224
let expr_session = ExprSession::default();
247225
let optimizer = session.optimizer(ExprOptimizer::new(&expr_session));
248226
let optimized = optimizer.optimize_array(expr_array.into_array()).unwrap();
@@ -297,7 +275,7 @@ mod tests {
297275
assert!(expr_array.child().is::<ALPVTable>());
298276

299277
let session = ArraySession::default();
300-
crate::register_alp_rules(&session);
278+
crate::initialize(&session);
301279
let expr_session = ExprSession::default();
302280
let optimizer = session.optimizer(ExprOptimizer::new(&expr_session));
303281
let optimized = optimizer.optimize_array(expr_array.into_array()).unwrap();
@@ -343,7 +321,7 @@ mod tests {
343321
ExprArray::new_infer_dtype(alp.clone().into_array(), expr.clone()).unwrap();
344322

345323
let session = ArraySession::default();
346-
crate::register_alp_rules(&session);
324+
crate::initialize(&session);
347325
let expr_session = ExprSession::default();
348326
let optimizer = session.optimizer(ExprOptimizer::new(&expr_session));
349327
let optimized = optimizer.optimize_array(expr_array.into_array()).unwrap();
@@ -398,7 +376,7 @@ mod tests {
398376
let test_value = 0.06051f32;
399377

400378
let session = ArraySession::default();
401-
crate::register_alp_rules(&session);
379+
crate::initialize(&session);
402380
let expr_session = ExprSession::default();
403381
let expr_optimizer = ExprOptimizer::new(&expr_session);
404382

encodings/alp/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@
1818
1919
pub use alp::*;
2020
pub use alp_rd::*;
21+
use vortex_array::EncodingRef;
2122

2223
mod alp;
2324
mod alp_rd;
2425

2526
/// Register ALP optimization rules with an ArraySession.
26-
pub fn register_alp_rules(session: &vortex_array::ArraySession) {
27+
pub fn initialize(session: &vortex_array::ArraySession) {
2728
use vortex_array::arrays::{ExprEncoding, ExprVTable};
2829

30+
session.register(EncodingRef::new_ref(ALPEncoding.as_ref()));
31+
2932
// Register the comparison pushdown rule for ALP arrays wrapped in ExprArray
3033
session.register_parent_rule::<ALPVTable, ExprVTable, _>(
3134
&ALPEncoding,

vortex-file/src/lib.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ pub use footer::*;
106106
pub use forever_constant::*;
107107
pub use open::*;
108108
pub use strategy::*;
109-
use vortex_alp::{ALPEncoding, ALPRDEncoding, register_alp_rules};
109+
use vortex_alp::{ALPEncoding, ALPRDEncoding};
110110
use vortex_array::arrays::DictEncoding;
111111
use vortex_array::{ArraySessionExt, EncodingRef};
112112
use vortex_bytebool::ByteBoolEncoding;
@@ -160,7 +160,6 @@ mod forever_constant {
160160
/// Vortex "Editions" that may support different sets of encodings.
161161
pub fn register_default_encodings(session: &VortexSession) {
162162
session.arrays().register_many([
163-
EncodingRef::new_ref(ALPEncoding.as_ref()),
164163
EncodingRef::new_ref(ALPRDEncoding.as_ref()),
165164
EncodingRef::new_ref(BitPackedEncoding.as_ref()),
166165
EncodingRef::new_ref(ByteBoolEncoding.as_ref()),
@@ -180,5 +179,5 @@ pub fn register_default_encodings(session: &VortexSession) {
180179
EncodingRef::new_ref(vortex_zstd::ZstdEncoding.as_ref()),
181180
]);
182181

183-
register_alp_rules(&session.arrays())
182+
vortex_alp::initialize(&session.arrays())
184183
}

0 commit comments

Comments
 (0)