Skip to content

Commit 3c37481

Browse files
committed
add take for slices, missing avx2 implementation
Signed-off-by: Connor Tsui <[email protected]>
1 parent fe4c81b commit 3c37481

File tree

9 files changed

+282
-0
lines changed

9 files changed

+282
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-buffer/src/buffer.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,27 @@ impl<T> Buffer<T> {
412412
}
413413
}
414414

415+
/// Cast a `Buffer<T>` into a `Buffer<U>`.
416+
///
417+
/// # Panics
418+
///
419+
/// Panics if the type `U` does not have the same size and alignment as `T`.
420+
pub fn cast_into<U>(self) -> Buffer<U> {
421+
assert_eq!(size_of::<T>(), size_of::<U>(), "Buffer type size mismatch");
422+
assert_eq!(
423+
align_of::<T>(),
424+
align_of::<U>(),
425+
"Buffer type alignment mismatch"
426+
);
427+
428+
Buffer {
429+
bytes: self.bytes,
430+
length: self.length,
431+
alignment: self.alignment,
432+
_marker: PhantomData,
433+
}
434+
}
435+
415436
/// Try to convert self into `BufferMut<T>` if there is only a single strong reference.
416437
pub fn try_into_mut(self) -> Result<BufferMut<T>, Self> {
417438
self.bytes

vortex-compute/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ vortex-vector = { workspace = true }
2929
arrow-array = { workspace = true, optional = true }
3030
arrow-buffer = { workspace = true, optional = true }
3131
arrow-schema = { workspace = true, optional = true }
32+
multiversion = { workspace = true }
3233
num-traits = { workspace = true }
3334

3435
[features]

vortex-compute/src/take/buffer.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::Buffer;
5+
use vortex_dtype::NativePType;
6+
use vortex_dtype::UnsignedPType;
7+
8+
use crate::take::Take;
9+
10+
impl<T: NativePType, I: UnsignedPType> Take<[I]> for &Buffer<T> {
11+
type Output = Buffer<T>;
12+
13+
fn take(self, indices: &[I]) -> Buffer<T> {
14+
self.as_slice().take(indices)
15+
}
16+
}

vortex-compute/src/take/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
//! Take function.
55
6+
mod buffer;
7+
mod slice;
8+
69
/// Function for taking based on indices (which can have different representations).
710
pub trait Take<Indices: ?Sized> {
811
/// The result type after performing the operation.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#![cfg(any(target_arch = "x86_64", target_arch = "x86"))]
5+
6+
use vortex_buffer::Buffer;
7+
use vortex_dtype::NativePType;
8+
use vortex_dtype::UnsignedPType;
9+
10+
#[allow(dead_code, unused_variables, reason = "TODO(connor): Implement this")]
11+
#[inline]
12+
pub fn take_avx2<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
13+
todo!(
14+
"TODO(connor): Migrate the implementation in \
15+
vortex-array/src/arrays/primitive/compute/take/avx2.rs"
16+
)
17+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_buffer::Buffer;
5+
use vortex_dtype::NativePType;
6+
use vortex_dtype::UnsignedPType;
7+
8+
use crate::take::Take;
9+
10+
mod avx2;
11+
mod portable;
12+
13+
/// Specialized implementation for non-nullable indices.
14+
impl<T: NativePType, I: UnsignedPType> Take<[I]> for &[T] {
15+
type Output = Buffer<T>;
16+
17+
fn take(self, indices: &[I]) -> Buffer<T> {
18+
#[cfg(vortex_nightly)]
19+
{
20+
return portable::take_portable(self, indices);
21+
}
22+
23+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
24+
{
25+
if is_x86_feature_detected!("avx2") {
26+
return avx2::take_avx2(self, indices);
27+
}
28+
}
29+
30+
take_scalar(self, indices)
31+
}
32+
}
33+
34+
#[allow(
35+
unused,
36+
reason = "Compiler may see this as unused based on enabled features"
37+
)]
38+
#[inline]
39+
fn take_scalar<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {
40+
indices.iter().map(|idx| buffer[idx.as_()]).collect()
41+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#![cfg(vortex_nightly)]
5+
6+
use std::mem::MaybeUninit;
7+
use std::mem::transmute;
8+
use std::simd;
9+
use std::simd::num::SimdUint;
10+
11+
use multiversion::multiversion;
12+
use vortex_buffer::Alignment;
13+
use vortex_buffer::Buffer;
14+
use vortex_buffer::BufferMut;
15+
use vortex_dtype::NativePType;
16+
use vortex_dtype::PType;
17+
use vortex_dtype::UnsignedPType;
18+
19+
#[inline]
20+
pub fn take_portable<T, I>(buffer: &[T], indices: &[I]) -> Buffer<T>
21+
where
22+
T: NativePType + simd::SimdElement,
23+
I: UnsignedPType + simd::SimdElement,
24+
{
25+
if T::PTYPE == PType::F16 {
26+
// Since Rust does not actually support 16-bit floats, we first reinterpret the data as
27+
// `u16` integers.
28+
let u16_slice: &[u16] =
29+
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) };
30+
31+
let taken_u16 = take_portable_simd::<u16, I, SIMD_WIDTH>(u16_slice, indices);
32+
let taken_f16 = taken_u16.cast_into::<T>();
33+
34+
taken_f16
35+
} else {
36+
take_portable_simd::<T, I, SIMD_WIDTH>(buffer, indices)
37+
}
38+
}
39+
40+
/// Takes elements from an array using SIMD indexing.
41+
///
42+
/// Performs a gather operation that takes values at specified indices and returns them in a new
43+
/// buffer. Uses SIMD instructions to process `LANE_COUNT` indices in parallel.
44+
///
45+
/// Returns a `Buffer<T>` where each element corresponds to `values[indices[i]]`.
46+
#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))]
47+
fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T>
48+
where
49+
T: NativePType + simd::SimdElement,
50+
I: UnsignedPType + simd::SimdElement,
51+
simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
52+
simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
53+
{
54+
let indices_len = indices.len();
55+
56+
let mut buffer = BufferMut::<T>::with_capacity_aligned(
57+
indices_len,
58+
Alignment::of::<simd::Simd<T, LANE_COUNT>>(),
59+
);
60+
61+
let buf_slice = buffer.spare_capacity_mut();
62+
63+
for chunk_idx in 0..(indices_len / LANE_COUNT) {
64+
let offset = chunk_idx * LANE_COUNT;
65+
let mask = simd::Mask::from_bitmask(u64::MAX);
66+
let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
67+
68+
let selection = simd::Simd::gather_select(
69+
values,
70+
mask,
71+
codes_chunk.cast::<usize>(),
72+
simd::Simd::<T, LANE_COUNT>::default(),
73+
);
74+
75+
unsafe {
76+
selection.store_select_unchecked(
77+
transmute::<&mut [MaybeUninit<T>], &mut [T]>(&mut buf_slice[offset..][..64]),
78+
mask.cast(),
79+
);
80+
}
81+
}
82+
83+
for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
84+
unsafe {
85+
buf_slice
86+
.get_unchecked_mut(idx)
87+
.write(values[indices[idx].as_()]);
88+
}
89+
}
90+
91+
unsafe {
92+
buffer.set_len(indices_len);
93+
}
94+
95+
buffer.freeze()
96+
}
97+
98+
#[cfg(test)]
99+
mod tests {
100+
use super::take_portable_simd;
101+
102+
#[test]
103+
fn test_take_out_of_bounds() {
104+
let indices = vec![2_000_000u32; 64];
105+
let values = vec![1i32];
106+
107+
let result = take_portable_simd::<i32, u32, 64>(&values, &indices);
108+
assert_eq!(result.as_slice(), [0i32; 64]);
109+
}
110+
}

vortex-vector/src/primitive/macros.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,42 @@ macro_rules! match_each_integer_pvector {
9797
}};
9898
}
9999

100+
/// Matches on all unsigned type variants of [`PrimitiveVector`] and executes the same code for each
101+
/// of the unsigned variant branches.
102+
///
103+
/// This macro eliminates repetitive match statements when implementing operations that need to work
104+
/// uniformly across all unsigned type variants (`U8`, `U16`, `U32`, `U64`).
105+
///
106+
/// See [`match_each_pvector`] for similar usage.
107+
///
108+
/// [`PrimitiveVector`]: crate::primitive::PrimitiveVector
109+
///
110+
/// # Panics
111+
///
112+
/// Panics if the vector passed in to the macro is not an unsigned vector variant.
113+
#[macro_export]
114+
macro_rules! match_each_unsigned_pvector {
115+
($self:expr, | $vec:ident | $body:block) => {{
116+
match $self {
117+
$crate::primitive::PrimitiveVector::U8($vec) => $body,
118+
$crate::primitive::PrimitiveVector::U16($vec) => $body,
119+
$crate::primitive::PrimitiveVector::U32($vec) => $body,
120+
$crate::primitive::PrimitiveVector::U64($vec) => $body,
121+
$crate::primitive::PrimitiveVector::I8(_)
122+
| $crate::primitive::PrimitiveVector::I16(_)
123+
| $crate::primitive::PrimitiveVector::I32(_)
124+
| $crate::primitive::PrimitiveVector::I64(_)
125+
| $crate::primitive::PrimitiveVector::F16(_)
126+
| $crate::primitive::PrimitiveVector::F32(_)
127+
| $crate::primitive::PrimitiveVector::F64(_) => {
128+
::vortex_error::vortex_panic!(
129+
"Tried to match a non-unsigned vector in an unsigned match statement"
130+
)
131+
}
132+
}
133+
}};
134+
}
135+
100136
/// Matches on all primitive type variants of [`PrimitiveVectorMut`] and executes the same code
101137
/// for each variant branch.
102138
///
@@ -184,3 +220,39 @@ macro_rules! match_each_integer_pvector_mut {
184220
}
185221
}};
186222
}
223+
224+
/// Matches on all unsigned type variants of [`PrimitiveVectorMut`] and executes the same code for
225+
/// each of the unsigned variant branches.
226+
///
227+
/// This macro eliminates repetitive match statements when implementing operations that need to work
228+
/// uniformly across all unsigned type variants (`U8`, `U16`, `U32`, `U64`).
229+
///
230+
/// See [`match_each_pvector_mut`] for similar usage.
231+
///
232+
/// [`PrimitiveVectorMut`]: crate::primitive::PrimitiveVectorMut
233+
///
234+
/// # Panics
235+
///
236+
/// Panics if the vector passed in to the macro is not an unsigned vector variant.
237+
#[macro_export]
238+
macro_rules! match_each_unsigned_pvector_mut {
239+
($self:expr, | $vec:ident | $body:block) => {{
240+
match $self {
241+
$crate::primitive::PrimitiveVectorMut::U8($vec) => $body,
242+
$crate::primitive::PrimitiveVectorMut::U16($vec) => $body,
243+
$crate::primitive::PrimitiveVectorMut::U32($vec) => $body,
244+
$crate::primitive::PrimitiveVectorMut::U64($vec) => $body,
245+
$crate::primitive::PrimitiveVectorMut::I8(_)
246+
| $crate::primitive::PrimitiveVectorMut::I16(_)
247+
| $crate::primitive::PrimitiveVectorMut::I32(_)
248+
| $crate::primitive::PrimitiveVectorMut::I64(_)
249+
| $crate::primitive::PrimitiveVectorMut::F16(_)
250+
| $crate::primitive::PrimitiveVectorMut::F32(_)
251+
| $crate::primitive::PrimitiveVectorMut::F64(_) => {
252+
::vortex_error::vortex_panic!(
253+
"Tried to match a non-unsigned mutable vector in an unsigned match statement"
254+
)
255+
}
256+
}
257+
}};
258+
}

0 commit comments

Comments
 (0)