|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +#![cfg(vortex_nightly)] |
| 5 | + |
| 6 | +use std::mem::MaybeUninit; |
| 7 | +use std::mem::transmute; |
| 8 | +use std::simd; |
| 9 | +use std::simd::num::SimdUint; |
| 10 | + |
| 11 | +use multiversion::multiversion; |
| 12 | +use vortex_buffer::Alignment; |
| 13 | +use vortex_buffer::Buffer; |
| 14 | +use vortex_buffer::BufferMut; |
| 15 | +use vortex_dtype::NativePType; |
| 16 | +use vortex_dtype::PType; |
| 17 | +use vortex_dtype::UnsignedPType; |
| 18 | + |
| 19 | +#[inline] |
| 20 | +pub fn take_portable<T, I>(buffer: &[T], indices: &[I]) -> Buffer<T> |
| 21 | +where |
| 22 | + T: NativePType + simd::SimdElement, |
| 23 | + I: UnsignedPType + simd::SimdElement, |
| 24 | +{ |
| 25 | + if T::PTYPE == PType::F16 { |
| 26 | + // Since Rust does not actually support 16-bit floats, we first reinterpret the data as |
| 27 | + // `u16` integers. |
| 28 | + let u16_slice: &[u16] = |
| 29 | + unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const u16, buffer.len()) }; |
| 30 | + |
| 31 | + let taken_u16 = take_portable_simd::<u16, I, SIMD_WIDTH>(u16_slice, indices); |
| 32 | + let taken_f16 = taken_u16.cast_into::<T>(); |
| 33 | + |
| 34 | + taken_f16 |
| 35 | + } else { |
| 36 | + take_portable_simd::<T, I, SIMD_WIDTH>(buffer, indices) |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +/// Takes elements from an array using SIMD indexing. |
| 41 | +/// |
| 42 | +/// Performs a gather operation that takes values at specified indices and returns them in a new |
| 43 | +/// buffer. Uses SIMD instructions to process `LANE_COUNT` indices in parallel. |
| 44 | +/// |
| 45 | +/// Returns a `Buffer<T>` where each element corresponds to `values[indices[i]]`. |
| 46 | +#[multiversion(targets("x86_64+avx2", "x86_64+avx", "aarch64+neon"))] |
| 47 | +fn take_portable_simd<T, I, const LANE_COUNT: usize>(values: &[T], indices: &[I]) -> Buffer<T> |
| 48 | +where |
| 49 | + T: NativePType + simd::SimdElement, |
| 50 | + I: UnsignedPType + simd::SimdElement, |
| 51 | + simd::LaneCount<LANE_COUNT>: simd::SupportedLaneCount, |
| 52 | + simd::Simd<I, LANE_COUNT>: SimdUint<Cast<usize> = simd::Simd<usize, LANE_COUNT>>, |
| 53 | +{ |
| 54 | + let indices_len = indices.len(); |
| 55 | + |
| 56 | + let mut buffer = BufferMut::<T>::with_capacity_aligned( |
| 57 | + indices_len, |
| 58 | + Alignment::of::<simd::Simd<T, LANE_COUNT>>(), |
| 59 | + ); |
| 60 | + |
| 61 | + let buf_slice = buffer.spare_capacity_mut(); |
| 62 | + |
| 63 | + for chunk_idx in 0..(indices_len / LANE_COUNT) { |
| 64 | + let offset = chunk_idx * LANE_COUNT; |
| 65 | + let mask = simd::Mask::from_bitmask(u64::MAX); |
| 66 | + let codes_chunk = simd::Simd::<I, LANE_COUNT>::from_slice(&indices[offset..]); |
| 67 | + |
| 68 | + let selection = simd::Simd::gather_select( |
| 69 | + values, |
| 70 | + mask, |
| 71 | + codes_chunk.cast::<usize>(), |
| 72 | + simd::Simd::<T, LANE_COUNT>::default(), |
| 73 | + ); |
| 74 | + |
| 75 | + unsafe { |
| 76 | + selection.store_select_unchecked( |
| 77 | + transmute::<&mut [MaybeUninit<T>], &mut [T]>(&mut buf_slice[offset..][..64]), |
| 78 | + mask.cast(), |
| 79 | + ); |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + for idx in ((indices_len / LANE_COUNT) * LANE_COUNT)..indices_len { |
| 84 | + unsafe { |
| 85 | + buf_slice |
| 86 | + .get_unchecked_mut(idx) |
| 87 | + .write(values[indices[idx].as_()]); |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + unsafe { |
| 92 | + buffer.set_len(indices_len); |
| 93 | + } |
| 94 | + |
| 95 | + buffer.freeze() |
| 96 | +} |
| 97 | + |
| 98 | +#[cfg(test)] |
| 99 | +mod tests { |
| 100 | + use super::take_portable_simd; |
| 101 | + |
| 102 | + #[test] |
| 103 | + fn test_take_out_of_bounds() { |
| 104 | + let indices = vec![2_000_000u32; 64]; |
| 105 | + let values = vec![1i32]; |
| 106 | + |
| 107 | + let result = take_portable_simd::<i32, u32, 64>(&values, &indices); |
| 108 | + assert_eq!(result.as_slice(), [0i32; 64]); |
| 109 | + } |
| 110 | +} |
0 commit comments