Skip to content

Commit eea184b

Browse files
committed
add take for bit buffer, mask, and primitive vector
Signed-off-by: Connor Tsui <[email protected]>
1 parent 3e430bf commit eea184b

File tree

6 files changed

+254
-0
lines changed

6 files changed

+254
-0
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::BitBuffer;
5+
use vortex_buffer::get_bit;
6+
use vortex_dtype::UnsignedPType;
7+
8+
use crate::take::Take;
9+
10+
impl<I: UnsignedPType> Take<[I]> for &BitBuffer {
11+
type Output = BitBuffer;
12+
13+
fn take(self, indices: &[I]) -> BitBuffer {
14+
// For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
15+
// the overhead to convert to a `Vec<bool>`.
16+
if self.len() <= 4096 {
17+
let bools = self.iter().collect();
18+
take_byte_bool(bools, indices)
19+
} else {
20+
take_bool(self, indices)
21+
}
22+
}
23+
}
24+
25+
// NB: We do NOT implement `impl<I: UnsignedPType> Take<PVector<I>> for &BitBuffer`, specifically
26+
// because there is a very similar implementation on `Mask` that has special logic for working with
27+
// null indices. That logic could also be implemented on `BitBuffer`, but since it is not
28+
// immediately clear what should happen in the case of a null index when taking a `BitBuffer` (do
29+
// you set it to true or false?), we do not implement this at all.
30+
31+
/// # Panics
32+
///
33+
/// Panics if an index is out of bounds.
34+
fn take_byte_bool<I: UnsignedPType>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
35+
BitBuffer::collect_bool(indices.len(), |idx| {
36+
// SAFETY: We are iterating within the bounds of the `indices` array, so we are always
37+
// within bounds of `indices`.
38+
let bool_idx = unsafe { indices.get_unchecked(idx).as_() };
39+
bools[bool_idx]
40+
})
41+
}
42+
43+
/// # Panics
44+
///
45+
/// Panics if an index is out of bounds.
46+
fn take_bool<I: UnsignedPType>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
47+
// We dereference to the underlying buffer to avoid incurring an access cost on every index.
48+
let buffer = bools.inner().as_ref();
49+
let offset = bools.offset();
50+
51+
BitBuffer::collect_bool(indices.len(), |idx| {
52+
// SAFETY: We are iterating within the bounds of the `indices` array, so we are always
53+
// within bounds.
54+
let bool_idx = unsafe { indices.get_unchecked(idx).as_() };
55+
get_bit(buffer, offset + bool_idx)
56+
})
57+
}

vortex-compute/src/take/mask.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::BitBuffer;
5+
use vortex_buffer::get_bit;
6+
use vortex_dtype::UnsignedPType;
7+
use vortex_mask::Mask;
8+
use vortex_vector::VectorOps;
9+
use vortex_vector::primitive::PVector;
10+
11+
use crate::take::Take;
12+
13+
impl<I: UnsignedPType> Take<[I]> for &Mask {
14+
type Output = Mask;
15+
16+
fn take(self, indices: &[I]) -> Mask {
17+
match self {
18+
Mask::AllTrue(_) => Mask::AllTrue(indices.len()),
19+
Mask::AllFalse(_) => Mask::AllFalse(indices.len()),
20+
Mask::Values(mask_values) => {
21+
let taken_bit_buffer = mask_values.bit_buffer().take(indices);
22+
Mask::from_buffer(taken_bit_buffer)
23+
}
24+
}
25+
}
26+
}
27+
28+
impl<I: UnsignedPType> Take<PVector<I>> for &Mask {
29+
type Output = Mask;
30+
31+
/// Implementation of take on [`Mask`] that is null-aware.
32+
///
33+
/// If an index is specified as null by the [`PVector`], then the taken mask value is set to
34+
/// `false`.
35+
///
36+
/// This is useful for many of the `take` implementations for vectors.
37+
fn take(self, indices: &PVector<I>) -> Mask {
38+
let indices_validity = indices.validity();
39+
let indices_len = indices.len();
40+
41+
match indices_validity {
42+
Mask::AllTrue(_) => return self.take(indices.elements().as_slice()),
43+
Mask::AllFalse(_) => return Mask::AllFalse(indices_len),
44+
Mask::Values(_) => (),
45+
};
46+
47+
let Mask::Values(indices_validity_values) = indices_validity else {
48+
unreachable!("we just matched on the other cases above");
49+
};
50+
51+
match self {
52+
// Since all the values are true, the only false values will be from the indices.
53+
Mask::AllTrue(_) => Mask::Values(indices_validity_values.clone()),
54+
// Since all the values are already false, the indices nullability wont change anything.
55+
Mask::AllFalse(_) => Mask::AllFalse(indices_len),
56+
Mask::Values(mask_values) => {
57+
// For boolean arrays that roughly fit into a single page (at least, on Linux), it's
58+
// worth the overhead to convert to a `Vec<bool>`.
59+
if self.len() <= 4096 {
60+
let bools = mask_values.bit_buffer().iter().collect();
61+
Mask::from_buffer(take_byte_bool_nullable(bools, indices))
62+
} else {
63+
Mask::from_buffer(take_bool_nullable(mask_values.bit_buffer(), indices))
64+
}
65+
}
66+
}
67+
}
68+
}
69+
70+
fn take_byte_bool_nullable<I: UnsignedPType>(bools: Vec<bool>, indices: &PVector<I>) -> BitBuffer {
71+
BitBuffer::collect_bool(indices.len(), |idx| {
72+
indices
73+
.get(idx)
74+
.is_some_and(|bool_idx| bools[bool_idx.as_()])
75+
})
76+
}
77+
78+
fn take_bool_nullable<I: UnsignedPType>(bools: &BitBuffer, indices: &PVector<I>) -> BitBuffer {
79+
// We dereference to the underlying buffer to avoid incurring an access cost on every index.
80+
let buffer = bools.inner().as_ref();
81+
let offset = bools.offset();
82+
83+
BitBuffer::collect_bool(indices.len(), |idx| {
84+
indices
85+
.get(idx)
86+
.is_some_and(|bool_idx| get_bit(buffer, offset + bool_idx.as_()))
87+
})
88+
}

vortex-compute/src/take/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
//! Take function.
55
6+
mod bit_buffer;
67
mod buffer;
8+
mod mask;
79
pub mod slice;
10+
mod vector;
811

912
/// Function for taking based on indices (which can have different representations).
1013
pub trait Take<Indices: ?Sized> {
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod primitive;
5+
mod pvector;
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::UnsignedPType;
5+
use vortex_vector::VectorOps;
6+
use vortex_vector::match_each_pvector;
7+
use vortex_vector::primitive::PVector;
8+
use vortex_vector::primitive::PrimitiveVector;
9+
10+
use crate::take::Take;
11+
12+
impl Take<PrimitiveVector> for &PrimitiveVector {
13+
type Output = PrimitiveVector;
14+
15+
fn take(self, indices: &PrimitiveVector) -> PrimitiveVector {
16+
match_each_pvector!(self, |v| { v.take(indices).into() })
17+
}
18+
}
19+
20+
impl<I: UnsignedPType> Take<PVector<I>> for &PrimitiveVector {
21+
type Output = PrimitiveVector;
22+
23+
fn take(self, indices: &PVector<I>) -> PrimitiveVector {
24+
// If all the indices are valid, we can delegate to the slice indices implementation.
25+
if indices.validity().all_true() {
26+
return self.take(indices.elements().as_slice());
27+
}
28+
29+
match_each_pvector!(self, |v| { v.take(indices).into() })
30+
}
31+
}
32+
33+
impl<I: UnsignedPType> Take<[I]> for &PrimitiveVector {
34+
type Output = PrimitiveVector;
35+
36+
fn take(self, indices: &[I]) -> PrimitiveVector {
37+
match_each_pvector!(self, |v| { v.take(indices).into() })
38+
}
39+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_dtype::NativePType;
5+
use vortex_dtype::UnsignedPType;
6+
use vortex_vector::VectorOps;
7+
use vortex_vector::match_each_unsigned_pvector;
8+
use vortex_vector::primitive::PVector;
9+
use vortex_vector::primitive::PrimitiveVector;
10+
11+
use crate::take::Take;
12+
13+
impl<T: NativePType> Take<PrimitiveVector> for &PVector<T> {
14+
type Output = PVector<T>;
15+
16+
fn take(self, indices: &PrimitiveVector) -> PVector<T> {
17+
match_each_unsigned_pvector!(indices, |iv| { self.take(iv) })
18+
}
19+
}
20+
21+
impl<T: NativePType, I: UnsignedPType> Take<PVector<I>> for &PVector<T> {
22+
type Output = PVector<T>;
23+
24+
fn take(self, indices: &PVector<I>) -> PVector<T> {
25+
if indices.validity().all_true() {
26+
self.take(indices.elements().as_slice())
27+
} else {
28+
take_nullable(self, indices)
29+
}
30+
}
31+
}
32+
33+
impl<T: NativePType, I: UnsignedPType> Take<[I]> for &PVector<T> {
34+
type Output = PVector<T>;
35+
36+
fn take(self, indices: &[I]) -> PVector<T> {
37+
let taken_elements = self.elements().take(indices);
38+
let taken_validity = self.validity().take(indices);
39+
40+
debug_assert_eq!(taken_elements.len(), taken_validity.len());
41+
42+
// SAFETY: we called take on both components of the vector with the same indices, so the new
43+
// components must have the same length.
44+
unsafe { PVector::new_unchecked(taken_elements, taken_validity) }
45+
}
46+
}
47+
48+
fn take_nullable<T: NativePType, I: UnsignedPType>(
49+
pvector: &PVector<T>,
50+
indices: &PVector<I>,
51+
) -> PVector<T> {
52+
// We ignore nullability when taking the elements since we can let the `Mask` implementation
53+
// determine which elements are null.
54+
let taken_elements = pvector.elements().take(indices.elements().as_slice());
55+
let taken_validity = pvector.validity().take(indices);
56+
57+
debug_assert_eq!(taken_elements.len(), taken_validity.len());
58+
59+
// SAFETY: We used the same indices to take from both components, so they should still have the
60+
// same length.
61+
unsafe { PVector::new_unchecked(taken_elements, taken_validity) }
62+
}

0 commit comments

Comments
 (0)