diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 66b68323043ee7..e678260d83835f 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -907,7 +907,8 @@ struct MHAHelper { new_causal, rnd_up(cur_kv_len, _block_size) - start_idx, precision_of::value, - precision_of::value); + precision_of::value, + nullptr); memset(score, 0, sizeof(DATA_TYPE) * start_idx); } else { @@ -929,6 +930,7 @@ struct MHAHelper { rnd_up(cur_kv_len, _block_size), precision_of::value, precision_of::value, + nullptr, alibi_slope); } if (score_output && m >= q_start_idx_score) { @@ -1083,7 +1085,8 @@ struct MHAHelper { new_causal, rnd_up(cur_kv_len, _block_size) - start_idx, precision_of::value, - precision_of::value); + precision_of::value, + nullptr); memset(score, 0, sizeof(DATA_TYPE) * start_idx); } else { @@ -1105,6 +1108,7 @@ struct MHAHelper { rnd_up(cur_kv_len, _block_size), precision_of::value, precision_of::value, + nullptr, alibi_slope); } if (score_output && m >= q_start_idx_score) { @@ -1223,6 +1227,7 @@ struct MHAHelper { cur_kv_len, ov::element::f32, ov::element::f32, + nullptr, alibi_slope); if (score_output) { // aligned to cache line to avoid false sharing @@ -1400,6 +1405,7 @@ struct MHAHelper { cur_kv_len, ov::element::f32, ov::element::f32, + nullptr, alibi_slope); }; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index 10c55293e72a29..3d3325064eefe3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -1407,7 +1407,8 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, ov::intel_cpu::PlainTensor& head_sum, size_t 
key_group_size, size_t value_group_size, - bool quant_key_by_channel) { + bool quant_key_by_channel, + const ov::intel_cpu::PlainTensor& sink_input) { ov::intel_cpu::PlainTensor causal_mask; bool select_nfltmax_at_0 = false; auto B = query.size(0); @@ -1591,6 +1592,10 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, attn_mask_ptr = reinterpret_cast(&attention_mask.at({b, h, pq, 0}, true)); } uint8_t* cmask_ptr = causal_mask ? &causal_mask.at({b, h, pq, 0}, true) : nullptr; + float* sink = nullptr; + if (sink_input) { + sink = &sink_input.at({b, h, pq, 0}, true); + } attn_softmax_kernel(buf_attn_w.ptr(b, h, pq), buf_attn_w.ptr(b, h, pq), d_scale, @@ -1601,7 +1606,8 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, ncausal, cur_kv_len, attn_mask_prec, - precision); + precision, + sink); }); // attn_w * V @@ -1719,7 +1725,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, ov::intel_cpu::PlainTensor& head_sum, size_t key_group_size, size_t value_group_size, - bool quant_key_by_channel) { + bool quant_key_by_channel, + const ov::intel_cpu::PlainTensor& sink_input) { if (query.get_precision() == ov::element::bf16) { if (present_key.get_precision() == ov::element::u8) { mha_single_token_kernel(query, @@ -1739,7 +1746,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } else { mha_single_token_kernel(query, present_key, @@ -1758,7 +1766,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } } else if (query.get_precision() == ov::element::f16) { #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) @@ -1780,7 +1789,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + 
quant_key_by_channel, + sink_input); } else { OPENVINO_THROW("Unsupported precision: ", present_key.get_precision()); } @@ -1803,7 +1813,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } else { mha_single_token_kernel(query, present_key, @@ -1822,7 +1833,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } #endif } else if (query.get_precision() == ov::element::f32) { @@ -1844,7 +1856,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } else if (present_key.get_precision() == ov::element::f16) { mha_single_token_kernel(query, present_key, @@ -1863,7 +1876,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } else { mha_single_token_kernel(query, present_key, @@ -1882,7 +1896,8 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, head_sum, key_group_size, value_group_size, - quant_key_by_channel); + quant_key_by_channel, + sink_input); } } else { OPENVINO_THROW("Unsupported precision: ", query.get_precision()); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp index b7a317c4a5e9c1..8c518c96d665ef 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp @@ -9,6 +9,14 @@ namespace ov::Extensions::Cpu::XARCH { +// if there is no sink input, please use a default PlainTensor as sink_input which doesn't contain any data, +// the softmax operation will
use the default formula: +// a[i] = exp(a[i] - max(a)); +// result[i] = a[i] / sum(a); +// if the sink_input contains data, +// the softmax formula becomes: +// a[i] = exp(a[i] - max(a, sink)); +// result[i] = a[i] / (sum(a) + exp(sink - max(a, sink))); void mha_single_token(const ov::intel_cpu::PlainTensor& query, const ov::intel_cpu::PlainTensor& present_key, const ov::intel_cpu::PlainTensor& present_value, @@ -26,6 +34,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, ov::intel_cpu::PlainTensor& head_sum, size_t key_group_size, size_t value_group_size, - bool quant_key_by_channel); + bool quant_key_by_channel, + const ov::intel_cpu::PlainTensor& sink_input); } // namespace ov::Extensions::Cpu::XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp index 70abab4717b03c..e04c67291a5f91 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp @@ -28,7 +28,8 @@ void attn_softmax(void* a, size_t total_size, [[maybe_unused]] ov::element::Type precision, ov::element::Type attn_mask_prec, - ov::element::Type dst_precision) { + ov::element::Type dst_precision, + const float* sink) { #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) if (precision == ov::element::f16) { auto _a = reinterpret_cast(a); @@ -43,7 +44,8 @@ void attn_softmax(void* a, len, total_size, attn_mask_prec, - dst_precision); + dst_precision, + sink); return; } #endif @@ -59,7 +61,8 @@ void attn_softmax(void* a, len, total_size, attn_mask_prec, - dst_precision); + dst_precision, + sink); } } // namespace ov::Extensions::Cpu::XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp index 509f0c64980150..fd4b866badccb3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp +++
b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp @@ -20,6 +20,6 @@ void attn_softmax(void* a, size_t total_size, ov::element::Type precision, ov::element::Type attn_mask_prec, - ov::element::Type dst_precision); - + ov::element::Type dst_precision, + const float* sink); } // namespace ov::Extensions::Cpu::XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp index 733765a47329ee..3bbe0fc3e8f3c6 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp @@ -1068,6 +1068,7 @@ inline void attn_softmax_kernel(T* a, size_t total_size, ov::element::Type attn_mask_prec, ov::element::Type dst_precision, + const float* sink, float alibi_slope = 0); template <> @@ -1082,6 +1083,7 @@ inline void attn_softmax_kernel(float* a, size_t total_size, ov::element::Type attn_mask_prec, ov::element::Type dst_precision, + const float* sink, float alibi_slope) { using func_fp32_type = void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float, float&); @@ -1148,8 +1150,14 @@ inline void attn_softmax_kernel(float* a, } float sum = 0.0f; + if (sink != nullptr) { + max = max > (*sink) ? 
max : (*sink); + } // exp sum exp_reduce_sum(a, max, len, sum); + if (sink != nullptr) { + sum += std::exp(*sink - max); + } // divide sum float scalar = 1.0f / sum; if (dst_precision == ov::element::f32) { @@ -1185,6 +1193,7 @@ inline void attn_softmax_kernel(ov::float16* a, size_t total_size, ov::element::Type attn_mask_prec, ov::element::Type dst_precision, + const float* sink, float alibi_slope) { using func_fp32_type = void (*)(ov::float16*, float, diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index f0639e24910898..c0dcf1b74aca65 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -122,13 +122,19 @@ struct MHAKernel { return result; } - void softmax(float* a, int len) { + void softmax(float* a, int len, const float* sink) { float max = *std::max_element(a, a + len); + if (sink != nullptr) { + max = max > (*sink) ? max : (*sink); + } float sum = 0.0F; for (int i = 0; i < len; i++) { a[i] = exp(a[i] - max); sum += a[i]; } + if (sink != nullptr) { + sum += exp((*sink) - max); + } float scale = 1.0F / sum; for (int i = 0; i < len; i++) { a[i] *= scale; @@ -164,6 +170,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, + PlainTensor& sink_input, float d_scale = 0.0F) { auto B = query.size(0); auto H = query.size(1); @@ -221,8 +228,12 @@ struct MHAKernel { } } + float* sink = nullptr; + if (sink_input) { + sink = &sink_input.at({b, h, m, 0}, true); + } // softmax - softmax(attn_score.data(), ncausal); + softmax(attn_score.data(), ncausal, sink); // linearly combine value word_vec.assign(head_size_v, 0.0F); @@ -391,6 +402,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, + PlainTensor& sink_input, float d_scale = 0.0F) { const auto B = query.size(0); const auto H = query.size(1); @@ -458,6 +470,11 @@ struct MHAKernel { // apply attention mask & sofmax auto 
ncausal = auto_causal ? (kv_len - q_len + m + 1) : kv_len; auto* score = weight_score.ptr(ithr, 0, m - m_start); + float* sink = nullptr; + if (sink_input) { + sink = &sink_input.at({b, h, m, 0}, true); + } + attn_softmax(reinterpret_cast(score), reinterpret_cast(score), d_scale, @@ -469,7 +486,8 @@ struct MHAKernel { kv_len, precision_of::value, precision_of::value, - precision_of::value); + precision_of::value, + sink); } auto* w_ptr = reinterpret_cast(weight_score.ptr(ithr, 0, 0, 0)); float* fp32_out_ptr = nullptr; @@ -542,6 +560,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, + PlainTensor& sink_input, float d_scale = 0.0F) { auto head_size = query.size(3); if (d_scale == 0.0F) { @@ -557,6 +576,7 @@ struct MHAKernel { output_emb, has_out_transpose, auto_causal, + sink_input, d_scale); } }; @@ -601,6 +621,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, + [[maybe_unused]] PlainTensor& sink_input, float d_scale = 0.0F) { auto B = query.size(0); auto H = query.size(1); @@ -686,7 +707,8 @@ struct MHAKernel { kv_len, precision, precision, - precision); + precision, + nullptr); } arm_compute::TensorInfo outInfo; arm_compute::Tensor outTensor; @@ -758,6 +780,7 @@ struct MHAKernel { PlainTensor& output_emb, bool has_out_transpose, bool auto_causal, + PlainTensor& sink_input, float d_scale = 0.0F) { auto B = query.size(0); auto H = query.size(1); @@ -853,6 +876,10 @@ struct MHAKernel { for (size_t m = m_start; m < m_end; m++) { // apply attention mask & sofmax auto ncausal = auto_causal ? 
(kv_len - q_len + m + 1) : kv_len; + float* sink = nullptr; + if (sink_input) { + sink = &sink_input.at({b, h, m, 0}, true); + } attn_softmax(reinterpret_cast(qk + (m - m_start) * qk_m_stride), qk + (m - m_start) * qk_m_stride, d_scale, @@ -864,7 +891,8 @@ struct MHAKernel { kv_len, ov::element::f32, ov::element::f32, - ov::element::f32); + ov::element::f32, + sink); } mlas_sgemm("N", "N", @@ -922,7 +950,8 @@ struct MHASingleToken { bool auto_causal, float d_scale, const PlainTensor& k_scale_zp, - const PlainTensor& v_scale_zp) { + const PlainTensor& v_scale_zp, + const PlainTensor& sink_input) { auto B = query.size(0); auto H = query.size(1); auto q_len = query.size(2); @@ -948,7 +977,8 @@ struct MHASingleToken { m_head_sum, m_key_group_size, m_value_group_size, - m_quant_key_by_channel); + m_quant_key_by_channel, + sink_input); } }; @@ -1003,6 +1033,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt PlainTensor beam_table; // i32[B, max_kvLen] PlainTensor attn_mask; PlainTensor output_emb(output); + PlainTensor sink_input; float scale_input = 0.0F; size_t B = 0; size_t L1 = 0; @@ -1065,6 +1096,9 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt SV = v_input.size(3); L0 = present_key.size(2) - L1; auto Hk = k_input.size(1); + if (input_num > 5) { + sink_input.reset(inputs[5]); + } if (fuse_concat) { k_input.assert_dims({B, Hk, L1, S}); @@ -1114,6 +1148,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt bool use_one_token = L1 == 1 || (fuse_concat && L0 > 0); if (!use_one_token) { // multi-token version + kernel(strm, q_input, k_input, @@ -1123,6 +1158,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt output_emb, has_out_transpose, auto_causal, + sink_input, scale_input); } else { // 1-token version @@ -1141,7 +1177,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt auto_causal, scale_input, 
k_scale_zp, - v_scale_zp); + v_scale_zp, + sink_input); } } }; @@ -1214,6 +1251,13 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { if (orginSDPInputNumber > 4) { config.inConfs[nextPortIdx].setMemDesc( creatorsMap.at(LayoutType::ncsp)->createSharedDesc(ov::element::f32, getInputShapeAtPort(nextPortIdx))); + nextPortIdx++; + } + // sink_input + if (orginSDPInputNumber > 5) { + config.inConfs[nextPortIdx].setMemDesc( + creatorsMap.at(LayoutType::ncsp)->createSharedDesc(ov::element::f32, getInputShapeAtPort(nextPortIdx))); + nextPortIdx++; } if (m_config.config.fuse_concat) { diff --git a/src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp b/src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp index 1e494a513e2d42..8a9686fd77ba06 100644 --- a/src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/shape_inference/custom/scaled_attn.cpp @@ -48,7 +48,7 @@ class SDPAShapeInfer : public ShapeInferEmptyPads { for (size_t i = 0; i < n_dims; i++) { output_dims[i] = query_dims[permute_axes[i]]; } - if (inputs_size == 7 && !m_config.is_causal) { + if (inputs_size >= 7 && !m_config.is_causal) { const auto& attn_mask_dims = input_shapes[3].get(); bool attn_mask_ok = true; auto attn_mask_dims_size = attn_mask_dims.size(); @@ -88,6 +88,55 @@ class SDPAShapeInfer : public ShapeInferEmptyPads { ov::intel_cpu::vec2str(cache_v_dims)); } } + if (inputs_size == 9) { + const auto& sink_dims = input_shapes[5].get(); + bool sink_ok = true; + auto weight_dims = output_dims; + auto weight_dims_size = weight_dims.size(); + if (weight_dims_size != sink_dims.size()) { + sink_ok = false; + } else { + weight_dims[3] = present_v_dims[length_index]; + auto check_broadcast = [](const size_t& target, const size_t& to) -> bool { + return any_of(target, to, 1U); + }; + sink_ok = sink_ok && check_broadcast(sink_dims[0], weight_dims[0]); + sink_ok = sink_ok && (sink_dims[1] == weight_dims[1]); + sink_ok = sink_ok 
&& check_broadcast(sink_dims[2], weight_dims[2]); + if (sink_dims[weight_dims_size - 1] != 1) { + sink_ok = false; + }; + } + if (!sink_ok) { + const auto& cur_k_dims = input_shapes[1].get(); + const auto& cur_v_dims = input_shapes[2].get(); + const auto& attn_mask_dims = input_shapes[3].get(); + const auto& scale_dims = input_shapes[4].get(); + const auto& sink_dims = input_shapes[5].get(); + const auto& beam_idx_dims = input_shapes[6].get(); + const auto& cache_k_dims = input_shapes[7].get(); + const auto& cache_v_dims = input_shapes[8].get(); + OPENVINO_THROW("sink input do not match q and k,", + " query_dims:", + ov::intel_cpu::vec2str(query_dims), + " cur_k_dims:", + ov::intel_cpu::vec2str(cur_k_dims), + " cur_v_dims:", + ov::intel_cpu::vec2str(cur_v_dims), + " attn_mask_dims:", + ov::intel_cpu::vec2str(attn_mask_dims), + " scale_dims:", + ov::intel_cpu::vec2str(scale_dims), + " sink_dims:", + ov::intel_cpu::vec2str(sink_dims), + " beam_idx_dims:", + ov::intel_cpu::vec2str(beam_idx_dims), + " cache_k_dims:", + ov::intel_cpu::vec2str(cache_k_dims), + " cache_v_dims:", + ov::intel_cpu::vec2str(cache_v_dims)); + } + } // normal and fast path if (present_v_dims[3] == query_dims[3]) { diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp index 6d8f021ae6fedc..dc3b08d29cfd23 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdpa_fusion.cpp @@ -59,6 +59,7 @@ StatefulSDPAFusion::StatefulSDPAFusion() { auto cur_q = any_input(); auto cur_k = any_input(); auto cur_v = any_input(); + auto atten_sink = any_input(); auto past_k = wrap_type(); auto past_v = wrap_type(); @@ -88,6 +89,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() { auto sdp1 = wrap_type({cur_q, present_k, present_v, any_input()}); auto sdp2 = 
wrap_type({cur_q, present_k, present_v, any_input(), any_input()}); + // gpt-oss + auto sdp3 = wrap_type( + {cur_q, present_k, present_v, any_input(), any_input(), atten_sink}); // non-canonical q/k/v shape definitions, for example: [L, B, H, S]/[B, L, H, S] auto order_k = wrap_type(); @@ -102,13 +106,22 @@ StatefulSDPAFusion::StatefulSDPAFusion() { wrap_type({transpose_q, transpose_k, transpose_v, any_input()}); auto sdp_trans2 = wrap_type( {transpose_q, transpose_k, transpose_v, any_input(), any_input()}); + // gpt-oss + auto sdp_trans3 = wrap_type( + {transpose_q, transpose_k, transpose_v, any_input(), any_input(), atten_sink}); - auto sdp = sdp0 | sdp1 | sdp2 | sdp_trans0 | sdp_trans1 | sdp_trans2; + auto sdp = sdp0 | sdp1 | sdp2 | sdp3 | sdp_trans0 | sdp_trans1 | sdp_trans2 | sdp_trans3; ov::matcher_pass_callback callback = [=](Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto root = m.get_match_root(); +// only support sink input on x86 platform currently. +#ifndef OPENVINO_ARCH_X86_64 + if (pattern_map.count(atten_sink)) { + return false; + } +#endif // Check concat axes equality first const auto concat_k_node = ov::as_type_ptr(pattern_map.at(concat_k).get_node_shared_ptr()); const auto concat_v_node = ov::as_type_ptr(pattern_map.at(concat_v).get_node_shared_ptr()); diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/sdpa_sink_input.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/sdpa_sink_input.cpp new file mode 100644 index 00000000000000..451f30b1890a0a --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/sdpa_sink_input.cpp @@ -0,0 +1,298 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include "common_test_utils/data_utils.hpp" +#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp" +#include "internal_properties.hpp" +#include 
"openvino/core/except.hpp" +#include "openvino/core/node_vector.hpp" +#include "openvino/core/type/float16.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/paged_attention.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/softmax.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/runtime/infer_request.hpp" +#include "openvino/runtime/tensor.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "utils/cpu_test_utils.hpp" +#include "utils/general_utils.h" + +using namespace ov::test; +using namespace CPUTestUtils; +using namespace ov::op; + +namespace ov { +namespace test { +using InputShapes = std::vector; +using PagedAttnTestParams = std::tuple; + +class SdpaSinkTest : public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest, + public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + const auto& [inType, inputShapes] = obj.param; + std::ostringstream result; + result << "IS="; + for (const auto& shape : inputShapes) { + result << ov::test::utils::partialShape2str({shape.first}) << "_"; + } + result << "TS="; + for (const auto& shape : inputShapes) { + result << "("; + if (!shape.second.empty()) { + for (const auto& itr : shape.second) { + result << ov::test::utils::vec2str(itr); + } + } + result << ")_"; + } + result << 
"Prc=" << inType; + + return result.str(); + } + static std::shared_ptr make_param(const PartialShape& pshape, + element::Type element_type, + const std::string& name) { + auto param = std::make_shared(element_type, pshape); + param->set_friendly_name(name); + param->get_output_tensor(0).set_names({name}); + return param; + } + virtual std::shared_ptr get_model(ov::element::Type data_type, + ov::Dimension::value_type head_size = 64, + ov::Dimension::value_type head_num = 8) { + // q, k, v use L,B,H,S layout + ov::PartialShape q_shape, kv_shape, past_shape, atten_mask_shape, scale_shape, sink_shape; + ov::ParameterVector inputParams; + past_shape = {-1, 1, head_num, head_size}; + q_shape = {-1, 1, static_cast(head_num), head_size}; + kv_shape = {-1, 1, head_num, head_size}; + atten_mask_shape = {1, head_num, -1, -1}; + scale_shape = {1}; + sink_shape = {1, head_num, 1, 1}; + + auto q = make_param(q_shape, data_type, "q"); + auto k = make_param(kv_shape, data_type, "k"); + auto v = make_param(kv_shape, data_type, "v"); + auto past_kv = make_param(past_shape, data_type, "past_kv"); + auto atten_mask = make_param(atten_mask_shape, data_type, "atten_mask"); + auto scale = make_param(scale_shape, data_type, "scale"); + auto sink = make_param(sink_shape, data_type, "sink"); + inputParams.push_back(q); + inputParams.push_back(k); + inputParams.push_back(v); + inputParams.push_back(atten_mask); + inputParams.push_back(scale); + inputParams.push_back(sink); + inputParams.push_back(past_kv); + auto var_k = + std::make_shared(ov::op::util::VariableInfo{past_shape, data_type, "pastk"}); + auto pastk = std::make_shared(inputParams[6], var_k); + pastk->set_friendly_name("pastk_r"); + auto var_v = + std::make_shared(ov::op::util::VariableInfo{past_shape, data_type, "pastv"}); + auto pastv = std::make_shared(inputParams[6], var_v); + pastv->set_friendly_name("pastv_r"); + std::vector transposeOrder{1, 2, 0, 3}; + auto preOrder = op::v0::Constant::create(ov::element::i32, {4}, 
transposeOrder); + std::shared_ptr q_in = std::make_shared(inputParams[0], preOrder); + + auto concat_axis = transposeOrder[2]; + auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape{-1}); + beam_idx->set_friendly_name("beam_idx"); + inputParams.push_back(beam_idx); + auto gatherK = + std::make_shared(pastk, + beam_idx, + op::v0::Constant::create(ov::element::i32, {1}, {transposeOrder[0]})); + auto gatherV = + std::make_shared(pastv, + beam_idx, + op::v0::Constant::create(ov::element::i32, {1}, {transposeOrder[0]})); + auto concatK = std::make_shared(OutputVector{gatherK, inputParams[1]}, concat_axis); + auto concatV = std::make_shared(OutputVector{gatherV, inputParams[2]}, concat_axis); + std::shared_ptr k_in = concatK; + std::shared_ptr v_in = concatV; + k_in = std::make_shared(k_in, preOrder); + v_in = std::make_shared(v_in, preOrder); + auto sdp = std::make_shared(q_in, k_in, v_in, atten_mask, scale, sink, false); + sdp->set_friendly_name("mha"); + auto pastk_assign = std::make_shared(concatK, var_k); + auto pastv_assign = std::make_shared(concatV, var_v); + pastk_assign->set_friendly_name("pastk_w"); + pastv_assign->set_friendly_name("pastv_w"); + auto get_reshape_order = [](const ov::PartialShape& qkv_shape, + const std::vector& transposeOrder) -> std::vector { + assert(transposeOrder.size() == 4); + auto H = qkv_shape[transposeOrder[1]].get_length(); + auto S = qkv_shape[transposeOrder[3]].get_length(); + return std::vector{0, static_cast(H * S)}; + }; + const auto reshapeOrder = get_reshape_order(q_shape, transposeOrder); + auto postOrder = + ov::op::v0::Constant::create(ov::element::i32, {4}, std::vector{2, 0, 1, 3}); // BHLS -> LBHS + auto transposeSDP = std::make_shared(sdp, postOrder); + + auto constReshape = ov::op::v0::Constant::create(ov::element::i32, {2}, reshapeOrder); + auto reshapeSDP = + std::make_shared(transposeSDP, + constReshape, + true); // use LBHS to better compare data between pa and sdpa + SinkVector sinks{pastk_assign, 
pastv_assign}; + ov::OutputVector results{reshapeSDP}; + auto model = std::make_shared(results, sinks, inputParams, "sdpa_model"); + return model; + } + + void SetUp() override { + const auto& [inType, inputShapes] = this->GetParam(); + targetDevice = ov::test::utils::DEVICE_CPU; + configuration[ov::hint::inference_precision.name()] = ov::element::f32; + if (inType == ElementType::bf16) { + configuration[ov::hint::inference_precision.name()] = ov::element::bf16; + configuration[ov::hint::kv_cache_precision.name()] = ov::element::f16; + rel_threshold = 0.02f; + abs_threshold = 0.02f; + } else if (inType == ElementType::f32) { + configuration[ov::hint::kv_cache_precision.name()] = ov::element::f32; + } else if (inType == ElementType::f16) { + configuration[ov::hint::kv_cache_precision.name()] = ov::element::f16; + } + init_input_shapes(inputShapes); + ov::ParameterVector inputParams; + + function = get_model(inType, 64, 8); + } + + virtual void generate(int idx, const std::vector& targetInputStaticShapes) { + inputs.clear(); + auto create_input = [this](std::shared_ptr param, ov::Shape shape, float val = 0) { + if (param->get_element_type() == ov::element::i32) { + ov::Tensor t{ov::element::i32, shape}; + auto size = shape[0]; + auto* p = static_cast(t.data()); + auto start = static_cast(val); + for (size_t i = 0; i < size; i++) { + p[i] = (start + i) % size; + } + inputs.insert({param, t}); + } else if (param->get_element_type() == ov::element::f32) { + ov::Tensor t{ov::element::f32, shape}; + utils::fill_data_random(static_cast(t.data()), t.get_size(), 2, -1, 10); + inputs.insert({param, t}); + } else if (param->get_element_type() == ov::element::f16) { + ov::Tensor t{ov::element::f16, shape}; + utils::fill_data_random(static_cast(t.data()), t.get_size(), 2, -1, 10); + inputs.insert({param, t}); + } else { + ASSERT_TRUE(param->get_element_type() == ov::element::bf16); + ov::Tensor t{ov::element::bf16, shape}; + utils::fill_data_random(static_cast(t.data()), 
t.get_size(), 2, -1, 10); + inputs.insert({param, t}); + } + }; + + // q, k, v, pastkv + create_input(function->get_parameters()[0], targetInputStaticShapes[0]); + create_input(function->get_parameters()[1], targetInputStaticShapes[0]); + create_input(function->get_parameters()[2], targetInputStaticShapes[0]); + create_input(function->get_parameters()[3], targetInputStaticShapes[2]); + create_input(function->get_parameters()[4], + function->get_parameters()[4]->get_partial_shape().to_shape()); + create_input(function->get_parameters()[5], + function->get_parameters()[5]->get_partial_shape().to_shape()); + create_input(function->get_parameters()[6], targetInputStaticShapes[1]); + create_input(function->get_parameters()[7], ov::Shape{targetInputStaticShapes[0][1]}); + } + void prepare() { + compile_model(); + inferRequest = compiledModel.create_infer_request(); + ASSERT_TRUE(inferRequest); + } + void reset() { + for (auto&& state : inferRequest.query_state()) { + state.reset(); + } + } + void run_test(std::shared_ptr model) { + function = model; + prepare(); + + auto core = ov::test::utils::PluginCache::get().core(); + ov::AnyMap configRef = {{"DISABLE_TRANSFORMATIONS" , "YES"}}; + auto compiledModelRef = core->compile_model(model, + ov::test::utils::DEVICE_TEMPLATE, configRef); + auto inferRequestRef = compiledModelRef.create_infer_request(); + + int idx = 0; + for (auto&& shapes : targetStaticShapes) { + generate(idx++, shapes); + for (const auto& input : inputs) { + inferRequest.set_tensor(input.first, input.second); + inferRequestRef.set_tensor(input.first, input.second); + } + inferRequest.infer(); + inferRequestRef.infer(); + auto logits = inferRequest.get_output_tensor(0); + auto logitsRef = inferRequestRef.get_output_tensor(0); + ov::test::utils::compare(logitsRef, logits, abs_threshold, rel_threshold); + } + reset(); + for (auto&& state : inferRequestRef.query_state()) { + state.reset(); + } + } +}; + +TEST_P(SdpaSinkTest, CompareWithRefs) { + 
SKIP_IF_CURRENT_TEST_IS_DISABLED(); + const auto& [inType, inputShapes] = this->GetParam(); + if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) + GTEST_SKIP(); + run_test(function); +} + +namespace { + +const std::vector inputShapes = { // greedy search + { + // q k v + {{-1, 1, 8, 64}, {{10, 1, 8, 64}, {1, 1, 8, 64}}}, + // past kv + {{-1, 1, 8, 64}, {{0, 1, 8, 64}, {10, 1, 8, 64}}}, + // attention_mask + {{1, 8, -1, -1}, {{1, 8, 10, 10}, {1, 8, 1, 11}}}, + }}; + +INSTANTIATE_TEST_SUITE_P(smoke_SdpaSinkTest, + SdpaSinkTest, + ::testing::Combine(::testing::Values(ElementType::f32, ElementType::f16, ElementType::bf16), + ::testing::ValuesIn(inputShapes)), + SdpaSinkTest::getTestCaseName); +} // namespace +} // namespace test +} // namespace ov diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/custom_shape_infer/scaled_attn.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/custom_shape_infer/scaled_attn.cpp index aae3f91ea550ba..37e32461bd8bf8 100644 --- a/src/plugins/intel_cpu/tests/unit/shape_inference_test/custom_shape_infer/scaled_attn.cpp +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/custom_shape_infer/scaled_attn.cpp @@ -252,6 +252,129 @@ INSTANTIATE_TEST_SUITE_P(CpuShapeInfer, SDPACpuShapeInferenceCorrectAttnMaskTest, ValuesIn(correctAttnmaskParams()), SDPACpuShapeInferenceCorrectAttnMaskTest::getTestCaseName); + +using SDPACpuShapeInferenceWrongSinkThrowExceptionTest = SDPACpuShapeInferenceTest; +TEST_P(SDPACpuShapeInferenceWrongSinkThrowExceptionTest, wrong_sink_input) { + ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config; + config.permute_axes = permute_axes; + config.is_causal = causal; + const auto op = make_op(args, config); + std::ostringstream os; + os << "sink input do not match q and k,"; + auto set_input_shape_str = [&os](std::string name, const StaticShape & input_shape) { + os << name; + os << "("; + for (size_t i = 0; i < input_shape.size(); i++) { + os << input_shape[i]; +
if (i < input_shape.size() - 1) { + os << "."; + } + } + os << ")"; + }; + set_input_shape_str(" query_dims:", input_shapes[0]); + set_input_shape_str(" cur_k_dims:", input_shapes[1]); + set_input_shape_str(" cur_v_dims:", input_shapes[2]); + set_input_shape_str(" attn_mask_dims:", input_shapes[3]); + set_input_shape_str(" scale_dims:", input_shapes[4]); + set_input_shape_str(" sink_dims:", input_shapes[5]); + set_input_shape_str(" beam_idx_dims:", input_shapes[6]); + set_input_shape_str(" cache_k_dims:", input_shapes[7]); + set_input_shape_str(" cache_v_dims:", input_shapes[8]); + OV_EXPECT_THROW(unit_test::cpu_test_shape_infer(op.get(), input_shapes, output_shapes), + ov::Exception, + HasSubstr(os.str())); +} + +auto wrongSinkParams = []() -> std::vector { + unit_test::ShapeVector attn_mask_vec = {{1, 16, 47, 47}, {47}, {1}, {1, 1, 1, 1}, {9, 47, 1}, {3, 1, 47, 1}, {1, 1, 1, 1, 1}}; + auto tuple = std::make_tuple( + unit_test::ShapeVector{{1, 16, 47, 56}, + {1, 8, 47, 56}, + {1, 8, 47, 56}, + {1, 1, 47, 94}, + {1}, + {1, 16, 1, 1}, + {1}, + {1, 8, 47, 56}, + {1, 8, 47, 56}}, + std::vector {}, + unit_test::ShapeVector{{1, 16, 47, 56}, {1, 8, 94, 56}, {1, 8, 94, 56}}, + false); + std::vector params; + auto createParams = [&attn_mask_vec, &tuple, &params]() { + for (auto& item : attn_mask_vec) { + auto& input_shapes = std::get<0>(tuple); + input_shapes[5] = item; + params.push_back(tuple); + } + }; + createParams(); + attn_mask_vec = {{3, 16, 1, 1}, {1}, {32, 1, 1}, {3, 1, 1, 6}, {1, 1}, {3, 1, 1, 1}}; + tuple = make_tuple(unit_test::ShapeVector{{3, 1, 32, 128}, + {3, 1, 32, 128}, + {3, 1, 32, 128}, + {3, 1, 1, 5}, + {1}, + {3, 32, 1, 1}, + {3}, + {3, 4, 32, 128}, + {3, 4, 32, 128}}, + std::vector {0, 2, 1, 3}, + unit_test::ShapeVector{{3, 32, 1, 128}, {3, 5, 32, 128}, {3, 5, 32, 128}}, + false); + createParams(); + return params; +}; + +INSTANTIATE_TEST_SUITE_P(CpuShapeInfer, + SDPACpuShapeInferenceWrongSinkThrowExceptionTest, + ValuesIn(wrongSinkParams()), +
SDPACpuShapeInferenceWrongSinkThrowExceptionTest::getTestCaseName); + +using SDPACpuShapeInferenceCorrectSinkTest = SDPACpuShapeInferenceTest; +TEST_P(SDPACpuShapeInferenceCorrectSinkTest, shape_inference) { + ov::intel_cpu::ScaledDotProductAttentionWithKVCache::Config config; + config.is_causal = causal; + config.permute_axes = permute_axes; + const auto op = make_op(args, config); + unit_test::cpu_test_shape_infer(op.get(), input_shapes, output_shapes); +} + +auto correctSinkParams = []() -> std::vector { + std::vector params; + auto tuple = std::make_tuple(unit_test::ShapeVector{{1, 16, 47, 56}, + {1, 8, 47, 56}, + {1, 8, 47, 56}, + {1, 1, 47, 94}, + {1}, + {1, 16, 1, 1}, + {1}, + {1, 8, 47, 56}, + {1, 8, 47, 56}}, + std::vector {}, + unit_test::ShapeVector{{1, 16, 47, 56}, {1, 8, 94, 56}, {1, 8, 94, 56}}, + false); + params.push_back(tuple); + tuple = make_tuple(unit_test::ShapeVector{{3, 1, 32, 128}, + {3, 1, 32, 128}, + {3, 1, 32, 128}, + {3, 1, 1, 5}, + {1}, + {3, 32, 1, 1}, + {3}, + {3, 4, 32, 128}, + {3, 4, 32, 128}}, + std::vector {0, 2, 1, 3}, + unit_test::ShapeVector{{3, 32, 1, 128}, {3, 5, 32, 128}, {3, 5, 32, 128}}, + false); + params.push_back(tuple); + return params; +}; +INSTANTIATE_TEST_SUITE_P(CpuShapeInfer, + SDPACpuShapeInferenceCorrectSinkTest, + ValuesIn(correctSinkParams()), + SDPACpuShapeInferenceCorrectSinkTest::getTestCaseName); } // namespace cpu_shape_infer } // namespace unit_test } // namespace intel_cpu