
Commit b234f19

Maajid Khan (maajidkhan.n@fujitsu.com), divya2108, and abhishek-iitmadras authored and committed
Extending the PyTorch vec backend for SVE ISA (ARM)

This contribution adds an SVE backend for Vec, with a 512-bit vector length, to the ATen CPU vectorization backend. Any ARM-architecture CPU that supports SVE can benefit from it. Currently this flow can be tested on A64FX (Fugaku) servers, which have SVE 512 hardware support.

Signed-off-by: maajidkhann <[email protected]>
Co-authored-by: Divya Kotadiya <[email protected]>
Co-authored-by: Abhishek Kumar <[email protected]>
1 parent 1cd4199 commit b234f19

28 files changed: 2362 additions & 11 deletions

CMakeLists.txt

Lines changed: 16 additions & 0 deletions
@@ -1139,6 +1139,22 @@ int main() {
   endif()
 endif()
 
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
+  include(CheckCSourceCompiles)
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8-a+sve")
+  check_c_source_compiles("#include <arm_sve.h>
+  int main() {
+    svfloat64_t a;
+    a = svdup_n_f64(0);
+    return 0;
+  }" COMPILER_HAS_ARM_SVE)
+
+  if(COMPILER_HAS_ARM_SVE)
+    string(APPEND CMAKE_CXX_FLAGS " -DCOMPILER_HAS_ARM_SVE")
+  endif()
+  set(CMAKE_C_FLAGS ${ORIGINAL_CMAKE_C_FLAGS})
+endif()
+
 # Add code coverage flags to supported compilers
 if(USE_CPP_CODE_COVERAGE)
   if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")

aten/src/ATen/Version.cpp

Lines changed: 5 additions & 0 deletions
@@ -105,6 +105,11 @@ std::string get_cpu_capability() {
       return "DEFAULT";
     case native::CPUCapability::ZVECTOR:
       return "Z VECTOR";
+#elif defined(HAVE_SVE_CPU_DEFINITION)
+    case native::CPUCapability::DEFAULT:
+      return "DEFAULT";
+    case native::CPUCapability::SVE512:
+      return "SVE512";
 #else
     case native::CPUCapability::DEFAULT:
       return "NO AVX";

aten/src/ATen/cpu/vec/functional_base.h

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ struct VecReduceAllSIMD<float, Op> {
 #endif // defined(CPU_CAPABILITY_AVX512)
 #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
 
-#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
+#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
 template <typename Op>
 struct VecReduceAllSIMD<float, Op> {
   static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {
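Why the extra guard: the block behind this #if is a NEON-based specialization of VecReduceAllSIMD<float, Op> that assumes the default aarch64 layout of Vectorized<float>. When CPU_CAPABILITY_SVE is defined, Vectorized<float> becomes a 512-bit SVE-backed type, so the NEON specialization must be compiled out. Under SVE a horizontal reduction collapses to a single ACLE call; a minimal sketch of the add case only (illustration, not the commit's code):

#include <arm_sve.h>

// Sum every lane of an SVE float vector. svaddv_f32 is the standard
// ACLE horizontal-add reduction; svptrue_b32() activates all lanes.
static inline float reduce_add_all(svfloat32_t v) {
  return svaddv_f32(svptrue_b32(), v);
}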

aten/src/ATen/cpu/vec/intrinsics.h

Lines changed: 4 additions & 0 deletions
@@ -41,3 +41,7 @@
 /* GCC-compatible compiler, targeting PowerPC with SPE */
 #include <spe.h>
 #endif
+/* CLANG and GCC-compatible compilers, targeting ARM with SVE */
+#if defined(COMPILER_HAS_ARM_SVE)
+#include <arm_sve.h>
+#endif
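With the header included under this guard, downstream code can branch on the same macro to pick an SVE path with a scalar fallback. A hedged sketch (not from this commit; svcntw, svwhilelt_b32, svld1_f32, svmul_n_f32_x, and svst1_f32 are standard ACLE calls):

#include <cstddef>
#include <cstdint>
#if defined(COMPILER_HAS_ARM_SVE)
#include <arm_sve.h>
#endif

// Scale n floats in place, vectorized when the toolchain supports SVE.
void scale(float* x, std::size_t n, float k) {
#if defined(COMPILER_HAS_ARM_SVE)
  for (std::size_t i = 0; i < n; i += svcntw()) {  // svcntw() = float lanes per vector
    // The predicate masks off the tail when n is not a multiple of the VL.
    svbool_t pg = svwhilelt_b32((std::uint64_t)i, (std::uint64_t)n);
    svst1_f32(pg, x + i, svmul_n_f32_x(pg, svld1_f32(pg, x + i), k));
  }
#else
  for (std::size_t i = 0; i < n; ++i) x[i] *= k;   // scalar fallback
#endif
}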
aten/src/ATen/cpu/vec/sve/sve_helper.h

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <ATen/cpu/vec/intrinsics.h>
+
+#include <ATen/cpu/vec/vec_base.h>
+
+#if defined(CPU_CAPABILITY_SVE)
+
+// Macros for the vector width.
+// Define the data types of VLS (vector-length specific) vectors.
+typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
+
+#define ptrue svptrue_b8()
+#define ZERO_S8 svdup_n_s8(0)
+#define ZERO_S16 svdup_n_s16(0)
+#define ZERO_S32 svdup_n_s32(0)
+#define ZERO_S64 svdup_n_s64(0)
+#define ZERO_U8 svdup_n_u8(0)
+#define ZERO_U16 svdup_n_u16(0)
+#define ZERO_U32 svdup_n_u32(0)
+#define ZERO_U64 svdup_n_u64(0)
+#define ZERO_F16 svdup_n_f16(0.f)
+#define ZERO_F32 svdup_n_f32(0.f)
+#define ZERO_F64 svdup_n_f64(0.0)
+#define ONE_S8 svdup_n_s8(1)
+#define ONE_S16 svdup_n_s16(1)
+#define ONE_S32 svdup_n_s32(1)
+#define ONE_S64 svdup_n_s64(1)
+#define ONE_U8 svdup_n_u8(1)
+#define ONE_U16 svdup_n_u16(1)
+#define ONE_U32 svdup_n_u32(1)
+#define ONE_U64 svdup_n_u64(1)
+#define ONE_F16 svdup_n_f16(1.f)
+#define ONE_F32 svdup_n_f32(1.f)
+#define ONE_F64 svdup_n_f64(1.0)
+#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
+#define ALL_S8_FALSE_MASK svdup_n_s8(0x0)
+#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff)
+#define ALL_S16_FALSE_MASK svdup_n_s16(0x0)
+#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff)
+#define ALL_S32_FALSE_MASK svdup_n_s32(0x0)
+#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff)
+#define ALL_S64_FALSE_MASK svdup_n_s64(0x0)
+#define ALL_U8_TRUE_MASK svdup_n_u8(0x01)
+#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
+#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
+#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
+#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
+#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
+#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
+#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)
+
+#endif // defined(CPU_CAPABILITY_SVE)
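The arm_sve_vector_bits attribute is what makes these typedefs usable inside the Vectorized classes: plain SVE types such as svfloat32_t are "sizeless" in C and C++ and cannot be class data members, while a VLS typedef of fixed width can. VECTOR_WIDTH (from the Vec base headers) is in bytes, so VECTOR_WIDTH * 8 is the width in bits, 512 here. A minimal illustration (not from this commit; the attribute requires compiling with -msve-vector-bits=512 so the fixed width matches):

#include <arm_sve.h>

typedef svfloat32_t fixed_f32_t __attribute__((arm_sve_vector_bits(512)));

struct MyVec {
  fixed_f32_t values;  // OK: a fixed-length SVE type has a known sizeof
  // svfloat32_t raw;  // error: a sizeless type cannot be a data member
};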
aten/src/ATen/cpu/vec/sve/vec_common_sve.h

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
+#pragma once
+
+// DO NOT DEFINE STATIC DATA IN THIS HEADER!
+// See Note [Do not compile initializers with SVE]
+
+#include <ATen/cpu/vec/intrinsics.h>
+
+#include <ATen/cpu/vec/vec_base.h>
+#include <ATen/cpu/vec/sve/sve_helper.h>
+
+#if defined(CPU_CAPABILITY_SVE)
+#include <ATen/cpu/vec/sve/vec_float.h>
+#include <ATen/cpu/vec/sve/vec_double.h>
+#include <ATen/cpu/vec/sve/vec_int.h>
+#include <ATen/cpu/vec/sve/vec_qint.h>
+#endif
+
+namespace at {
+namespace vec {
+// Note [CPU_CAPABILITY namespace]
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// This header, and all of its subheaders, will be compiled with
+// different architecture flags for each supported set of vector
+// intrinsics. So we need to make sure they aren't inadvertently
+// linked together. We do this by declaring objects in an `inline
+// namespace` which changes the name mangling, but can still be
+// accessed as `at::vec`.
+inline namespace CPU_CAPABILITY {
+
+#if defined(CPU_CAPABILITY_SVE)
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<>
+inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
+  return svreinterpret_f32_f64(src);
+}
+
+template<>
+inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
+  return svreinterpret_f64_f32(src);
+}
+
+#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit)                 \
+template<>                                                                        \
+inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) {   \
+  return svreinterpret_s##int_bit##_f##float_bit(src);                            \
+}                                                                                 \
+template<>                                                                        \
+inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) {   \
+  return svreinterpret_f##float_bit##_s##int_bit(src);                            \
+}
+
+DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
+DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
+DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
+DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
+DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
+DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
+inline gather(const double* base_addr, const Vectorized<int64_t>& vindex_) {
+  svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
+  return svld1_gather_s64index_f64(ptrue, base_addr, vindex);
+}
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
+inline gather(const float* base_addr, const Vectorized<int32_t>& vindex_) {
+  svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
+  return svld1_gather_s32index_f32(ptrue, base_addr, vindex);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
+inline mask_gather(const Vectorized<double>& src, const double* base_addr,
+                   const Vectorized<int64_t>& vindex_, const Vectorized<double>& mask_) {
+  svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_),
+                              ALL_S64_TRUE_MASK);
+  svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
+  return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src);
+}
+
+template<int64_t scale = 1>
+std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
+inline mask_gather(const Vectorized<float>& src, const float* base_addr,
+                   const Vectorized<int32_t>& vindex_, const Vectorized<float>& mask_) {
+  svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_),
+                              ALL_S32_TRUE_MASK);
+  svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
+  return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+// Only works for inputs in the range: [-2^51, 2^51]
+// From: https://stackoverflow.com/a/41148578
+template<>
+Vectorized<int64_t>
+inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
+  svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000));
+  return svsub_s64_x(ptrue,
+                     svreinterpret_s64_f64(x),
+                     svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000)));
+}
+
+template<>
+Vectorized<int32_t>
+inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
+  return svcvt_s32_f32_x(ptrue, src);
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <>
+std::pair<Vectorized<double>, Vectorized<double>>
+inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3}
+  //   b = {b0, b1, b2, b3}
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1}
+  //          {a2, b2, a3, b3}
+  return std::make_pair(Vectorized<double>(svzip1_f64(a, b)),
+                        Vectorized<double>(svzip2_f64(a, b)));
+}
+
+template <>
+std::pair<Vectorized<float>, Vectorized<float>>
+inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  // inputs:
+  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
+  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
+  // group cols crossing lanes:
+  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
+  //          {a4, b4, a5, b5, a6, b6, a7, b7}
+  return std::make_pair(Vectorized<float>(svzip1_f32(a, b)),
+                        Vectorized<float>(svzip2_f32(a, b)));
+}
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <>
+std::pair<Vectorized<double>, Vectorized<double>>
+inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1}
+  //   b = {a2, b2, a3, b3}
+  // swap lanes:
+  //   return {a0, a1, a2, a3}
+  //          {b0, b1, b2, b3}
+  return std::make_pair(Vectorized<double>(svuzp1_f64(a, b)),
+                        Vectorized<double>(svuzp2_f64(a, b)));
+}
+
+template <>
+std::pair<Vectorized<float>, Vectorized<float>>
+inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  // inputs:
+  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
+  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
+  // swap lanes:
+  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
+  //          {b0, b1, b2, b3, b4, b5, b6, b7}
+  return std::make_pair(Vectorized<float>(svuzp1_f32(a, b)),
+                        Vectorized<float>(svuzp2_f32(a, b)));
+}
+
+#endif // defined(CPU_CAPABILITY_SVE)
+
+}}}
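The convert_to_int_of_same_size<double> specialization above avoids a float-to-int instruction by exploiting the binary64 layout: 0x0018000000000000 as an integer is 2^52 + 2^51, and adding that value (6755399441055744.0 as a double) forces the rounded integer part of the input into the low mantissa bits, so subtracting the constant's bit pattern recovers the integer. A scalar restatement of the trick (illustration only, with the same [-2^51, 2^51] range restriction):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Branch-free double -> int64, valid only for |x| <= 2^51.
static std::int64_t fast_f64_to_i64(double x) {
  const double magic = 6755399441055744.0;  // 2^52 + 2^51
  double shifted = x + magic;               // integer part lands in low mantissa bits
  std::int64_t shifted_bits, magic_bits;
  std::memcpy(&shifted_bits, &shifted, sizeof shifted_bits);
  std::memcpy(&magic_bits, &magic, sizeof magic_bits);
  return shifted_bits - magic_bits;
}

int main() {
  std::printf("%lld %lld\n",
              (long long)fast_f64_to_i64(42.0),   // prints 42
              (long long)fast_f64_to_i64(-7.0));  // prints -7
  return 0;
}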
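interleave2/deinterleave2 map directly onto SVE's zip and unzip permutes: svzip1/svzip2 merge the low and high halves of two vectors lane by lane, and svuzp1/svuzp2 extract the even and odd lanes back out. (The lane counts in the comments above assume the 512-bit configuration: 8 floats or 4 doubles per vector.) A small standalone demonstration (not from this commit; the 64-element buffers cover any legal SVE vector length):

#include <arm_sve.h>
#include <cstdio>

int main() {
  float a_buf[64], b_buf[64], lo[64];
  for (int i = 0; i < 64; ++i) { a_buf[i] = (float)i; b_buf[i] = 100.0f + i; }

  svbool_t pg = svptrue_b32();
  svfloat32_t a = svld1_f32(pg, a_buf);
  svfloat32_t b = svld1_f32(pg, b_buf);

  // svzip1 interleaves the low halves: {a0, b0, a1, b1, ...}
  svst1_f32(pg, lo, svzip1_f32(a, b));
  std::printf("%g %g %g %g\n", lo[0], lo[1], lo[2], lo[3]);  // 0 100 1 101
  return 0;
}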
