Merged
22 commits (changes shown from 9 of them)
9111971
save temp code
tiger100256-hu Oct 9, 2025
50fbad7
fix compile issue
tiger100256-hu Oct 10, 2025
3584a28
debug function and accuracy
tiger100256-hu Oct 11, 2025
5583cbd
fix bug in code
tiger100256-hu Oct 13, 2025
1009f61
remove debug print
tiger100256-hu Oct 13, 2025
a8306f5
add test case for sdpa sink input
tiger100256-hu Oct 13, 2025
3f50597
clean code
tiger100256-hu Oct 13, 2025
013e6f5
remove debug code in cpu transformation pipeline and format test code
tiger100256-hu Oct 14, 2025
9964440
fix format and compiling issue
tiger100256-hu Oct 14, 2025
55cbf93
disable sink input on non-x86 platform
tiger100256-hu Oct 14, 2025
1ff6d8b
do not use default value and use const float* as sink type
tiger100256-hu Oct 14, 2025
bb8249e
fix build issue on ARM
tiger100256-hu Oct 15, 2025
0d20fc2
compare max and sink outside reduce_max
tiger100256-hu Oct 15, 2025
9bb170d
revert modification in cmake, do not use default value now
tiger100256-hu Oct 15, 2025
4d66b01
check that sink input precision is fp32
tiger100256-hu Oct 15, 2025
f5920a4
add maybe_unused and remove unused code
tiger100256-hu Oct 15, 2025
7b5e843
use C++ std::exp, only one float input
tiger100256-hu Oct 15, 2025
195bb04
don't need to check the origin precision, it will add reorder later
tiger100256-hu Oct 16, 2025
5fe160b
use the at(index, broadcast_true) interface of PlainTensor
tiger100256-hu Oct 17, 2025
c41c5b0
add shapeinfer check for sink input
tiger100256-hu Oct 17, 2025
50bf292
fix issue in shapeinfer
tiger100256-hu Oct 17, 2025
061c496
add comment about sink_input
tiger100256-hu Oct 21, 2025
@@ -54,10 +54,10 @@ function(_generate_dispatcher)
string(APPEND DISP_CONTENT
"namespace ${_arch} {\n ${SIGNATURE}\; \n}\n")
endforeach()

+## remove default value in SIGNATURE
+string(REGEX REPLACE "[ ]*=[ ]*[a-zA-Z0-9_]+[ ]*" "" SIGNATURE_NO_DEFAULT ${SIGNATURE})
string(APPEND DISP_CONTENT
-"namespace ${XARCH_CURRENT_NAMESPACE} {\n\n${SIGNATURE} {\n")
+"namespace ${XARCH_CURRENT_NAMESPACE} {\n\n${SIGNATURE_NO_DEFAULT} {\n")
foreach(_arch IN LISTS XARCH_SET)
string(APPEND DISP_CONTENT
" if (${_CPU_CHECK_${_arch}}) {\n return ${_arch}::${CALL_LINE}\;\n }\n")
@@ -98,6 +98,9 @@ function(_generate_call_line_from_signature SIGNATURE RESULT_NAME)
string(REPLACE ")" "" _args ${_args})
string(REPLACE "," ";" _args ${_args}) # now it's list
foreach(_arg_elem ${_args})
+string(STRIP ${_arg_elem} _arg_elem)
+## remove default value
+string(REGEX REPLACE "=.*$" "" _arg_elem ${_arg_elem})
string(REGEX MATCH "[a-zA-Z0-9_]*[ ]*$" _arg_elem "${_arg_elem}")
list(APPEND _arg_names ${_arg_elem})
endforeach()
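Both hunks deal with the same C++ rule: a default argument may be specified only once in a scope, and the generator pastes `SIGNATURE` into both the per-arch declarations and the dispatcher's own definition. A minimal sketch of the constraint being worked around, using a simplified two-parameter signature (the real `attn_softmax` takes many more arguments):

```cpp
// Declaration (as in the public header): the default may appear here, once.
void attn_softmax(float* a, void* sink = nullptr);

// Generated dispatcher definition: it must use the stripped signature,
// since repeating "= nullptr" would redefine the default argument.
void attn_softmax(float* a, void* sink) {
    // ...dispatch to the arch-specific implementation here
}
```

The second hunk strips defaults before extracting argument names for the forwarding call; without it, `void* sink = nullptr` would match `nullptr` as the trailing identifier instead of `sink`.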
@@ -1407,7 +1407,8 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
ov::intel_cpu::PlainTensor& head_sum,
size_t key_group_size,
size_t value_group_size,
-bool quant_key_by_channel) {
+bool quant_key_by_channel,
+const ov::intel_cpu::PlainTensor& sink_input) {
ov::intel_cpu::PlainTensor causal_mask;
bool select_nfltmax_at_0 = false;
auto B = query.size(0);
@@ -1591,6 +1592,8 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
attn_mask_ptr = reinterpret_cast<uint8_t*>(&attention_mask.at<T>({b, h, pq, 0}, true));
}
uint8_t* cmask_ptr = causal_mask ? &causal_mask.at<uint8_t>({b, h, pq, 0}, true) : nullptr;
+
+auto sink = sink_input.safe_ptr<T3>(b, h, pq);
attn_softmax_kernel<T3>(buf_attn_w.ptr<T3>(b, h, pq),
buf_attn_w.ptr<T3>(b, h, pq),
d_scale,
@@ -1601,7 +1604,9 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query,
ncausal,
cur_kv_len,
attn_mask_prec,
-precision);
+precision,
+0,
+sink);
});

// attn_w * V
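Here `sink_input.safe_ptr<T3>(b, h, pq)` is expected to return `nullptr` when no sink tensor is attached, so kernels without a sink fall back to the plain softmax; the bare `0` now passed to `attn_softmax_kernel` is its `alibi_slope` argument, which has to be spelled out explicitly once `sink` follows it.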
@@ -1719,7 +1724,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
ov::intel_cpu::PlainTensor& head_sum,
size_t key_group_size,
size_t value_group_size,
-bool quant_key_by_channel) {
+bool quant_key_by_channel,
+const ov::intel_cpu::PlainTensor& sink_input) {
if (query.get_precision() == ov::element::bf16) {
if (present_key.get_precision() == ov::element::u8) {
mha_single_token_kernel<ov::bfloat16, uint8_t, float>(query,
@@ -1739,7 +1745,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
} else {
mha_single_token_kernel<ov::bfloat16, ov::bfloat16, float>(query,
present_key,
@@ -1758,7 +1765,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
}
} else if (query.get_precision() == ov::element::f16) {
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
@@ -1803,7 +1811,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
} else {
mha_single_token_kernel<ov::float16, ov::float16, float>(query,
present_key,
@@ -1822,7 +1831,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
}
#endif
} else if (query.get_precision() == ov::element::f32) {
@@ -1844,7 +1854,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
} else if (present_key.get_precision() == ov::element::f16) {
mha_single_token_kernel<float, ov::float16, float>(query,
present_key,
@@ -1863,7 +1874,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
} else {
mha_single_token_kernel<float, float, float>(query,
present_key,
@@ -1882,7 +1894,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
head_sum,
key_group_size,
value_group_size,
-quant_key_by_channel);
+quant_key_by_channel,
+sink_input);
}
} else {
OPENVINO_THROW("Unsupported precision: ", query.get_precision());
@@ -26,6 +26,6 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query,
ov::intel_cpu::PlainTensor& head_sum,
size_t key_group_size,
size_t value_group_size,
-bool quant_key_by_channel);
+bool quant_key_by_channel,
+const ov::intel_cpu::PlainTensor& sink_input);
Comment on lines -29 to +38

Contributor: Please add a short comment about the purpose of this sink input and what should be set if there is no sink input.

Contributor Author (tiger100256-hu): Already added the comment, please help to check it, thanks.

} // namespace ov::Extensions::Cpu::XARCH
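For context on what the new tensor means: a sink is a per-(batch, head, query) logit, as used by models with learned attention sinks, that takes part in the softmax max and normalization sum but never emits an attention weight, so its only effect is to scale the other weights down. A scalar sketch of the math the vectorized kernels below implement (illustrative code, not the kernel itself):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>

// Reference softmax over one row of attention scores w[0..len),
// with an optional fp32 sink logit (nullptr means "no sink").
void softmax_with_sink(float* w, std::size_t len, const float* sink) {
    float max = -std::numeric_limits<float>::infinity();
    for (std::size_t i = 0; i < len; i++) max = std::max(max, w[i]);
    if (sink) max = std::max(max, *sink);  // sink joins the running max...
    float sum = 0.0f;
    for (std::size_t i = 0; i < len; i++) {
        w[i] = std::exp(w[i] - max);
        sum += w[i];
    }
    if (sink) sum += std::exp(*sink - max);  // ...and the denominator,
    for (std::size_t i = 0; i < len; i++) w[i] /= sum;  // but gets no weight of its own
}
```

Passing an empty `PlainTensor` (so the sink pointer resolves to `nullptr`) keeps the old behaviour.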
@@ -28,7 +28,8 @@ void attn_softmax(void* a,
size_t total_size,
[[maybe_unused]] ov::element::Type precision,
ov::element::Type attn_mask_prec,
-ov::element::Type dst_precision) {
+ov::element::Type dst_precision,
+void* sink) {
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
if (precision == ov::element::f16) {
auto _a = reinterpret_cast<ov::float16*>(a);
@@ -49,6 +50,7 @@ void attn_softmax(void* a,
#endif
auto* _a = reinterpret_cast<float*>(a);
auto* _alibi = reinterpret_cast<float*>(alibi);
+auto* _sink = reinterpret_cast<float*>(sink);
attn_softmax_kernel<float>(_a,
a_dst,
scale,
@@ -59,7 +61,9 @@ void attn_softmax(void* a,
len,
total_size,
attn_mask_prec,
-dst_precision);
+dst_precision,
+0,
+_sink);
}

} // namespace ov::Extensions::Cpu::XARCH
@@ -20,6 +20,6 @@ void attn_softmax(void* a,
size_t total_size,
ov::element::Type precision,
ov::element::Type attn_mask_prec,
-ov::element::Type dst_precision);
+ov::element::Type dst_precision,
+void* sink = nullptr);
} // namespace ov::Extensions::Cpu::XARCH
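The fp32-only default lives here in the declaration; the dispatcher generator above strips it from the generated definition, so existing `attn_softmax` call sites keep compiling unchanged while the generated code stays well-formed.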
@@ -191,7 +191,8 @@ inline void scale_add2_reduce_max(float* a,
bool select_nfltmax_at_0, // true: 0 in mask set -FLT_MAX
size_t size,
float alibi_slope,
-float& max) {
+float& max,
+float* sink = nullptr) {
size_t i = 0;
#if defined(HAVE_AVX512F)
auto v_max0 = _mm512_set1_ps(std::numeric_limits<float>::lowest());
@@ -289,6 +290,11 @@ inline void scale_add2_reduce_max(float* a,

i += (size - i);
}
+if (sink != nullptr) {
+__mmask16 mask = 1;
+v_a = _mm512_maskz_loadu_ps(mask, sink);
+v_max0 = _mm512_mask_max_ps(v_max0, mask, v_a, v_max0);
+}

v_max0 = _mm512_max_ps(v_max0, v_max1);
v_max2 = _mm512_max_ps(v_max2, v_max3);
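The added block folds the sink into the max reduction without a scalar epilogue: `__mmask16 mask = 1` selects lane 0 only, `_mm512_maskz_loadu_ps` loads just `sink[0]` into that lane, and the masked max merges it into `v_max0`, the vector equivalent of `max = std::max(max, *sink)` from the sketch above.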
@@ -475,7 +481,8 @@ inline void scale_add2_reduce_max(ov::float16* a,
bool select_nfltmax_at_0, // true: 0 in mask set -FLT_MAX
size_t size,
float alibi_slope,
-ov::float16& max) {
+ov::float16& max,
+float* sink = nullptr) {
size_t i = 0;
# if defined(HAVE_SVE)
svfloat16_t v_max = svdup_n_f16(static_cast<float16_t>(-FLT_MAX));
@@ -671,7 +678,7 @@ static inline void exp_ps_avx512(__m512& src) {
}
#endif

-inline void exp_reduce_sum(float* a, const float max, const size_t size, float& sum) {
+inline void exp_reduce_sum(float* a, const float max, const size_t size, float& sum, float* sink = nullptr) {
inline void exp_reduce_sum(float* a, const float max, const size_t size, float& sum, float* sink = nullptr) {
size_t i = 0;
#if defined(HAVE_AVX512F)
__m512 v_a;
@@ -696,6 +703,13 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&

i += (size - i);
}
+if (sink != nullptr) {
+__mmask16 mask = 1;
+v_a = _mm512_maskz_loadu_ps(mask, sink);
+v_a = _mm512_sub_ps(v_a, v_max);
+exp_ps_avx512(v_a);
+v_sum = _mm512_mask_add_ps(v_sum, mask, v_a, v_sum);
+}
sum = _mm512_reduce_add_ps(v_sum);
#elif defined(HAVE_AVX2)
__m256 v_a;
@@ -721,6 +735,14 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float&

i += (size - i);
}
+if (sink != nullptr) {
+auto mask = get_mask(1);
+v_a = _mm256_maskload_ps(sink, mask);
+v_a = _mm256_sub_ps(v_a, v_max);
+exp_ps_avx2(v_a);
+v_a = _mm256_blendv_ps(_mm256_setzero_ps(), v_a, _mm256_castsi256_ps(mask));
+v_sum = _mm256_add_ps(v_a, v_sum);
+}
hsum(v_sum);
sum = _mm256_cvtss_f32(v_sum);
#elif defined(OPENVINO_ARCH_ARM64)
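In both ISA paths the sink contributes only `exp(sink[0] - max)` to the running sum. AVX-512 can do the single-lane add with a mask directly; AVX2 has no masked add, so the loaded lane is re-blended against zero with `_mm256_blendv_ps`, otherwise the seven unselected lanes (zeroed by the masked load, then shifted by `-max` and exponentiated) would each leak `exp(-max)` into `v_sum`.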
@@ -1068,7 +1090,8 @@ inline void attn_softmax_kernel(T* a,
size_t total_size,
ov::element::Type attn_mask_prec,
ov::element::Type dst_precision,
-float alibi_slope = 0);
+float alibi_slope = 0,
+T* sink = nullptr);

template <>
inline void attn_softmax_kernel<float>(float* a,
@@ -1082,13 +1105,14 @@ inline void attn_softmax_kernel<float>(float* a,
size_t total_size,
ov::element::Type attn_mask_prec,
ov::element::Type dst_precision,
-float alibi_slope) {
+float alibi_slope,
+float* sink) {
using func_fp32_type =
-void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float, float&);
+void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float, float&, float*);
using func_bf16_type =
-void (*)(float*, float, const float*, const ov::bfloat16*, const uint8_t*, bool, size_t, float, float&);
+void (*)(float*, float, const float*, const ov::bfloat16*, const uint8_t*, bool, size_t, float, float&, float*);
using func_f16_type =
-void (*)(float*, float, const float*, const ov::float16*, const uint8_t*, bool, size_t, float, float&);
+void (*)(float*, float, const float*, const ov::float16*, const uint8_t*, bool, size_t, float, float&, float*);
static constexpr func_fp32_type funcs_fp32[] = {scale_add2_reduce_max<false, false, false>,
scale_add2_reduce_max<false, false, true>,
scale_add2_reduce_max<false, true, false>,
@@ -1124,7 +1148,8 @@ inline void attn_softmax_kernel<float>(float* a,
select_nfltmax_at_0,
len,
alibi_slope,
-max);
+max,
+sink);
} else if (attn_mask_prec == ov::element::bf16) {
funcs_bf16[dispatch](a,
scale,
@@ -1134,7 +1159,8 @@ inline void attn_softmax_kernel<float>(float* a,
select_nfltmax_at_0,
len,
alibi_slope,
-max);
+max,
+sink);
} else {
funcs_f16[dispatch](a,
scale,
Expand All @@ -1144,12 +1170,13 @@ inline void attn_softmax_kernel<float>(float* a,
select_nfltmax_at_0,
len,
alibi_slope,
-max);
+max,
+sink);
}

float sum = 0.0f;
// exp sum
-exp_reduce_sum(a, max, len, sum);
+exp_reduce_sum(a, max, len, sum, sink);
// divide sum
float scalar = 1.0f / sum;
if (dst_precision == ov::element::f32) {
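Note that the sink's exponent is never written back into `a`: it only enlarges the denominator inside `exp_reduce_sum`, so every stored attention weight shrinks accordingly, matching the scalar sketch above.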
@@ -1185,7 +1212,8 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
size_t total_size,
ov::element::Type attn_mask_prec,
ov::element::Type dst_precision,
-float alibi_slope) {
+float alibi_slope,
+ov::float16* sink = nullptr) {
using func_fp32_type = void (*)(ov::float16*,
float,
const ov::float16*,
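One caveat in this nine-commit view: `ov::float16* sink = nullptr` in the `attn_softmax_kernel<ov::float16>` explicit specialization spells a default argument, which C++ does not allow on explicit specializations of function templates; presumably this is what the later "do not use default value" commit cleans up.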