Skip to content

Commit f3ac42e

Browse files
committed
Call into vortex-compute's `take_portable_simd` instead of duplicating it in vortex-array
Signed-off-by: Connor Tsui <[email protected]>
1 parent 3c37481 commit f3ac42e

File tree

5 files changed

+17
-74
lines changed

5 files changed

+17
-74
lines changed

vortex-array/src/arrays/primitive/compute/take/portable.rs

Lines changed: 5 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use num_traits::AsPrimitive;
1818
use vortex_buffer::Alignment;
1919
use vortex_buffer::Buffer;
2020
use vortex_buffer::BufferMut;
21+
use vortex_compute::take::slice::portable;
2122
use vortex_dtype::NativePType;
2223
use vortex_dtype::PType;
2324
use vortex_dtype::match_each_native_simd_ptype;
@@ -46,9 +47,9 @@ impl TakeImpl for TakeKernelPortableSimd {
4647
if array.ptype() == PType::F16 {
4748
// Special handling for f16 to treat as opaque u16
4849
let decoded = match_each_unsigned_integer_ptype!(unsigned_indices.ptype(), |C| {
49-
take_portable_simd::<C, u16, SIMD_WIDTH>(
50-
unsigned_indices.as_slice(),
50+
portable::take_portable_simd::<u16, C, SIMD_WIDTH>(
5151
array.reinterpret_cast(PType::U16).as_slice(),
52+
unsigned_indices.as_slice(),
5253
)
5354
});
5455
Ok(PrimitiveArray::new(decoded, validity)
@@ -57,9 +58,9 @@ impl TakeImpl for TakeKernelPortableSimd {
5758
} else {
5859
match_each_unsigned_integer_ptype!(unsigned_indices.ptype(), |C| {
5960
match_each_native_simd_ptype!(array.ptype(), |V| {
60-
let decoded = take_portable_simd::<C, V, SIMD_WIDTH>(
61-
unsigned_indices.as_slice(),
61+
let decoded = portable::take_portable_simd::<V, C, SIMD_WIDTH>(
6262
array.as_slice(),
63+
unsigned_indices.as_slice(),
6364
);
6465
Ok(PrimitiveArray::new(decoded, validity).into_array())
6566
})
@@ -68,72 +69,6 @@ impl TakeImpl for TakeKernelPortableSimd {
6869
}
6970
}
7071

71-
/// Takes elements from an array using SIMD indexing.
72-
///
73-
/// # Type Parameters
74-
/// * `I` - Index type
75-
/// * `V` - Value type
76-
/// * `LANE_COUNT` - Number of SIMD lanes to process in parallel
77-
///
78-
/// # Parameters
79-
/// * `indices` - Indices to gather values from
80-
/// * `values` - Source values to index
81-
///
82-
/// # Returns
83-
/// A `Buffer<V>` containing the gathered values where each index has been replaced with
84-
/// the corresponding value from the source array.
85-
#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))]
86-
fn take_portable_simd<I, V, const LANE_COUNT: usize>(indices: &[I], values: &[V]) -> Buffer<V>
87-
where
88-
I: simd::SimdElement + AsPrimitive<usize>,
89-
V: simd::SimdElement + NativePType,
90-
simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount,
91-
simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>,
92-
{
93-
let indices_len = indices.len();
94-
95-
let mut buffer = BufferMut::<V>::with_capacity_aligned(
96-
indices_len,
97-
Alignment::of::<simd::Simd<V, LANE_COUNT>>(),
98-
);
99-
100-
let buf_slice = buffer.spare_capacity_mut();
101-
102-
for chunk_idx in 0..(indices_len / LANE_COUNT) {
103-
let offset = chunk_idx * LANE_COUNT;
104-
let mask = simd::Mask::from_bitmask(u64::MAX);
105-
let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]);
106-
107-
let selection = simd::Simd::gather_select(
108-
values,
109-
mask,
110-
codes_chunk.cast::<usize>(),
111-
simd::Simd::<V, LANE_COUNT>::default(),
112-
);
113-
114-
unsafe {
115-
selection.store_select_unchecked(
116-
transmute::<&mut [MaybeUninit<V>], &mut [V]>(&mut buf_slice[offset..][..64]),
117-
mask.cast(),
118-
);
119-
}
120-
}
121-
122-
for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len {
123-
unsafe {
124-
buf_slice
125-
.get_unchecked_mut(idx)
126-
.write(values[indices[idx].as_()]);
127-
}
128-
}
129-
130-
unsafe {
131-
buffer.set_len(indices_len);
132-
}
133-
134-
buffer.freeze()
135-
}
136-
13772
#[cfg(test)]
13873
mod tests {
13974
use super::take_portable_simd;

vortex-compute/src/take/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
//! Take function.
55
66
mod buffer;
7-
mod slice;
7+
pub mod slice;
88

99
/// Function for taking based on indices (which can have different representations).
1010
pub trait Take<Indices: ?Sized> {

vortex-compute/src/take/slice/avx2.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
//! Take function implementations on slices using AVX2 SIMD.
5+
46
#![cfg(any(target_arch = "x86_64", target_arch = "x86"))]
57

68
use vortex_buffer::Buffer;
79
use vortex_dtype::NativePType;
810
use vortex_dtype::UnsignedPType;
911

12+
/// Takes the specified indices into a new [`Buffer`] using AVX2 SIMD.
1013
#[allow(dead_code, unused_variables, reason = "TODO(connor): Implement this")]
1114
#[inline]
1215
pub fn take_avx2<T: NativePType, I: UnsignedPType>(buffer: &[T], indices: &[I]) -> Buffer<T> {

vortex-compute/src/take/slice/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
//! Take function implementations on slices.
5+
46
use vortex_buffer::Buffer;
57
use vortex_dtype::NativePType;
68
use vortex_dtype::UnsignedPType;
79

810
use crate::take::Take;
911

10-
mod avx2;
11-
mod portable;
12+
pub mod avx2;
13+
pub mod portable;
1214

1315
/// Specialized implementation for non-nullable indices.
1416
impl<T: NativePType, I: UnsignedPType> Take<[I]> for &[T] {

vortex-compute/src/take/slice/portable.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
//! Take function implementations on slices using `portable_simd`.
5+
46
#![cfg(vortex_nightly)]
57

68
use std::mem::MaybeUninit;
@@ -16,6 +18,7 @@ use vortex_dtype::NativePType;
1618
use vortex_dtype::PType;
1719
use vortex_dtype::UnsignedPType;
1820

21+
/// Takes the specified indices into a new [`Buffer`] using portable SIMD.
1922
#[inline]
2023
pub fn take_portable<T, I>(buffer: &[T], indices: &[I]) -> Buffer<T>
2124
where
@@ -44,7 +47,7 @@ where
4447
///
4548
/// Returns a `Buffer<T>` where each element corresponds to `values[indices[i]]`.
4649
#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))]
47-
fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T>
50+
pub fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T>
4851
where
4952
T: NativePType + simd::SimdElement,
5053
I: UnsignedPType + simd::SimdElement,

0 commit comments

Comments
 (0)