Skip to content

Commit ab6dec6

Browse files
authored
Enable vectorization of std::rotate on ARM64 (microsoft#5845)
1 parent d806de4 commit ab6dec6

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

stl/inc/xutility

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ _STL_DISABLE_CLANG_WARNINGS
9797
#define _VECTORIZED_REPLACE _VECTORIZED_FOR_X64_X86
9898
#define _VECTORIZED_REVERSE _VECTORIZED_FOR_X64_X86
9999
#define _VECTORIZED_REVERSE_COPY _VECTORIZED_FOR_X64_X86
100-
#define _VECTORIZED_ROTATE _VECTORIZED_FOR_X64_X86
100+
#define _VECTORIZED_ROTATE _VECTORIZED_FOR_X64_X86_ARM64
101101
#define _VECTORIZED_SEARCH _VECTORIZED_FOR_X64_X86
102102
#define _VECTORIZED_SEARCH_N _VECTORIZED_FOR_X64_X86
103103
#define _VECTORIZED_SWAP_RANGES _VECTORIZED_FOR_X64_X86_ARM64

stl/src/vector_algorithms.cpp

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,113 @@ void* __cdecl __std_swap_ranges_trivially_swappable(
250250

251251
} // extern "C"
252252

253-
#ifndef _M_ARM64
254253
namespace {
255254
namespace _Rotating {
255+
#ifdef _M_ARM64
256+
void __forceinline _Swap_3_ranges(void* _First1, void* const _Last1, void* _First2, void* _First3) noexcept {
257+
if (_Byte_length(_First1, _Last1) >= 64) {
258+
constexpr size_t _Mask_64 = ~((static_cast<size_t>(1) << 6) - 1);
259+
const void* _Stop_at = _First1;
260+
_Advance_bytes(_Stop_at, _Byte_length(_First1, _Last1) & _Mask_64);
261+
do {
262+
const uint8x16_t _Val1Lo1 = vld1q_u8(static_cast<uint8_t*>(_First1) + 0);
263+
const uint8x16_t _Val1Lo2 = vld1q_u8(static_cast<uint8_t*>(_First1) + 16);
264+
const uint8x16_t _Val1Hi1 = vld1q_u8(static_cast<uint8_t*>(_First1) + 32);
265+
const uint8x16_t _Val1Hi2 = vld1q_u8(static_cast<uint8_t*>(_First1) + 48);
266+
const uint8x16_t _Val2Lo1 = vld1q_u8(static_cast<uint8_t*>(_First2) + 0);
267+
const uint8x16_t _Val2Lo2 = vld1q_u8(static_cast<uint8_t*>(_First2) + 16);
268+
const uint8x16_t _Val2Hi1 = vld1q_u8(static_cast<uint8_t*>(_First2) + 32);
269+
const uint8x16_t _Val2Hi2 = vld1q_u8(static_cast<uint8_t*>(_First2) + 48);
270+
const uint8x16_t _Val3Lo1 = vld1q_u8(static_cast<uint8_t*>(_First3) + 0);
271+
const uint8x16_t _Val3Lo2 = vld1q_u8(static_cast<uint8_t*>(_First3) + 16);
272+
const uint8x16_t _Val3Hi1 = vld1q_u8(static_cast<uint8_t*>(_First3) + 32);
273+
const uint8x16_t _Val3Hi2 = vld1q_u8(static_cast<uint8_t*>(_First3) + 48);
274+
vst1q_u8(static_cast<uint8_t*>(_First1) + 0, _Val2Lo1);
275+
vst1q_u8(static_cast<uint8_t*>(_First1) + 16, _Val2Lo2);
276+
vst1q_u8(static_cast<uint8_t*>(_First1) + 32, _Val2Hi1);
277+
vst1q_u8(static_cast<uint8_t*>(_First1) + 48, _Val2Hi2);
278+
vst1q_u8(static_cast<uint8_t*>(_First2) + 0, _Val3Lo1);
279+
vst1q_u8(static_cast<uint8_t*>(_First2) + 16, _Val3Lo2);
280+
vst1q_u8(static_cast<uint8_t*>(_First2) + 32, _Val3Hi1);
281+
vst1q_u8(static_cast<uint8_t*>(_First2) + 48, _Val3Hi2);
282+
vst1q_u8(static_cast<uint8_t*>(_First3) + 0, _Val1Lo1);
283+
vst1q_u8(static_cast<uint8_t*>(_First3) + 16, _Val1Lo2);
284+
vst1q_u8(static_cast<uint8_t*>(_First3) + 32, _Val1Hi1);
285+
vst1q_u8(static_cast<uint8_t*>(_First3) + 48, _Val1Hi2);
286+
_Advance_bytes(_First1, 64);
287+
_Advance_bytes(_First2, 64);
288+
_Advance_bytes(_First3, 64);
289+
} while (_First1 != _Stop_at);
290+
}
291+
292+
if (_Byte_length(_First1, _Last1) >= 32) {
293+
const uint8x16_t _Val1Lo = vld1q_u8(static_cast<uint8_t*>(_First1) + 0);
294+
const uint8x16_t _Val1Hi = vld1q_u8(static_cast<uint8_t*>(_First1) + 16);
295+
const uint8x16_t _Val2Lo = vld1q_u8(static_cast<uint8_t*>(_First2) + 0);
296+
const uint8x16_t _Val2Hi = vld1q_u8(static_cast<uint8_t*>(_First2) + 16);
297+
const uint8x16_t _Val3Lo = vld1q_u8(static_cast<uint8_t*>(_First3) + 0);
298+
const uint8x16_t _Val3Hi = vld1q_u8(static_cast<uint8_t*>(_First3) + 16);
299+
vst1q_u8(static_cast<uint8_t*>(_First1) + 0, _Val2Lo);
300+
vst1q_u8(static_cast<uint8_t*>(_First1) + 16, _Val2Hi);
301+
vst1q_u8(static_cast<uint8_t*>(_First2) + 0, _Val3Lo);
302+
vst1q_u8(static_cast<uint8_t*>(_First2) + 16, _Val3Hi);
303+
vst1q_u8(static_cast<uint8_t*>(_First3) + 0, _Val1Lo);
304+
vst1q_u8(static_cast<uint8_t*>(_First3) + 16, _Val1Hi);
305+
_Advance_bytes(_First1, 32);
306+
_Advance_bytes(_First2, 32);
307+
_Advance_bytes(_First3, 32);
308+
}
309+
310+
if (_Byte_length(_First1, _Last1) >= 16) {
311+
const uint8x16_t _Val1 = vld1q_u8(static_cast<uint8_t*>(_First1));
312+
const uint8x16_t _Val2 = vld1q_u8(static_cast<uint8_t*>(_First2));
313+
const uint8x16_t _Val3 = vld1q_u8(static_cast<uint8_t*>(_First3));
314+
vst1q_u8(static_cast<uint8_t*>(_First1), _Val2);
315+
vst1q_u8(static_cast<uint8_t*>(_First2), _Val3);
316+
vst1q_u8(static_cast<uint8_t*>(_First3), _Val1);
317+
_Advance_bytes(_First1, 16);
318+
_Advance_bytes(_First2, 16);
319+
_Advance_bytes(_First3, 16);
320+
}
321+
322+
if (_Byte_length(_First1, _Last1) >= 8) {
323+
const uint8x8_t _Val1 = vld1_u8(static_cast<uint8_t*>(_First1));
324+
const uint8x8_t _Val2 = vld1_u8(static_cast<uint8_t*>(_First2));
325+
const uint8x8_t _Val3 = vld1_u8(static_cast<uint8_t*>(_First3));
326+
vst1_u8(static_cast<uint8_t*>(_First1), _Val2);
327+
vst1_u8(static_cast<uint8_t*>(_First2), _Val3);
328+
vst1_u8(static_cast<uint8_t*>(_First3), _Val1);
329+
_Advance_bytes(_First1, 8);
330+
_Advance_bytes(_First2, 8);
331+
_Advance_bytes(_First3, 8);
332+
}
333+
334+
if (_Byte_length(_First1, _Last1) >= 4) {
335+
uint32x2_t _Val1 = vdup_n_u32(0);
336+
uint32x2_t _Val2 = vdup_n_u32(0);
337+
uint32x2_t _Val3 = vdup_n_u32(0);
338+
_Val1 = vld1_lane_u32(static_cast<uint32_t*>(_First1), _Val1, 0);
339+
_Val2 = vld1_lane_u32(static_cast<uint32_t*>(_First2), _Val2, 0);
340+
_Val3 = vld1_lane_u32(static_cast<uint32_t*>(_First3), _Val3, 0);
341+
vst1_lane_u32(static_cast<uint32_t*>(_First1), _Val2, 0);
342+
vst1_lane_u32(static_cast<uint32_t*>(_First2), _Val3, 0);
343+
vst1_lane_u32(static_cast<uint32_t*>(_First3), _Val1, 0);
344+
_Advance_bytes(_First1, 4);
345+
_Advance_bytes(_First2, 4);
346+
_Advance_bytes(_First3, 4);
347+
}
348+
349+
auto _First1c = static_cast<unsigned char*>(_First1);
350+
auto _First2c = static_cast<unsigned char*>(_First2);
351+
auto _First3c = static_cast<unsigned char*>(_First3);
352+
for (; _First1c != _Last1; ++_First1c, ++_First2c, ++_First3c) {
353+
const unsigned char _Ch = *_First1c;
354+
*_First1c = *_First2c;
355+
*_First2c = *_First3c;
356+
*_First3c = _Ch;
357+
}
358+
}
359+
#else // ^^^ defined(_M_ARM64) / !defined(_M_ARM64) vvv
256360
void _Swap_3_ranges(void* _First1, void* const _Last1, void* _First2, void* _First3) noexcept {
257361
#ifndef _M_ARM64EC
258362
constexpr size_t _Mask_32 = ~((static_cast<size_t>(1) << 5) - 1);
@@ -346,6 +450,7 @@ namespace {
346450
*_First3c = _Ch;
347451
}
348452
}
453+
#endif // ^^^ !defined(_M_ARM64) ^^^
349454

350455
constexpr size_t _Buf_size = 512;
351456

@@ -418,6 +523,7 @@ __declspec(noalias) void __stdcall __std_rotate(void* _First, void* const _Mid,
418523

419524
} // extern "C"
420525

526+
#ifndef _M_ARM64
421527
namespace {
422528
namespace _Reversing {
423529
#ifdef _M_ARM64EC

0 commit comments

Comments
 (0)