
Commit 198e26e

riverlijunjie authored and ceciliapeng2011 committed

Support kv cache u8 precision

1 parent ffdf2f1 commit 198e26e

8 files changed: +574 −505 lines changed

src/common/transformations/src/transformations/common_optimizations/convert_pagedattn_inputs.cpp

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ ov::pass::ConvertPagedAttnInputs::ConvertPagedAttnInputs(const KVCacheConfig& co
     value_cache->set_element_type(value_cache_precision);
     bool status = false;
     if (pa_op->get_rt_info().count("num_k_heads") && pa_op->get_rt_info().count("k_head_size") &&
-        pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("num_v_heads")) {
+        pa_op->get_rt_info().count("num_v_heads") && pa_op->get_rt_info().count("v_head_size")) {
         const auto key_cache_shape = init_cache_shape(pa_op->get_rt_info()["num_k_heads"].as<size_t>(),
                                                       pa_op->get_rt_info()["k_head_size"].as<size_t>(),
                                                       m_config.keyCacheBlockSize,

src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp

Lines changed: 424 additions & 456 deletions
Large diffs are not rendered by default.

src/plugins/intel_gpu/src/graph/impls/cm/pa_kv_cache_update_ref.cm

Lines changed: 39 additions & 1 deletion
@@ -32,8 +32,13 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
     const int32_t* block_indices [[type("svmptr_t")]],
     const int32_t* block_indices_begins [[type("svmptr_t")]],
     const int32_t* subsequence_begins [[type("svmptr_t")]],
+#if KV_CACHE_COMPRESSION_PER_TOKEN
+    uint8_t* key_cache [[type("svmptr_t")]],
+    uint8_t* value_cache [[type("svmptr_t")]],
+#else
     half* key_cache [[type("svmptr_t")]],
-    half* value_cache [[type("svmptr_t")]],
+    half* value_cache [[type("svmptr_t")]],
+#endif
     uint32_t key_pitch,
     uint32_t key_offset,
     uint32_t value_pitch,
@@ -84,14 +89,43 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
 
     const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx;
 
+#if KV_CACHE_COMPRESSION_PER_TOKEN
+    // Assume: K_HEAD_SIZE == V_HEAD_SIZE
+    auto quantize_and_store = [&](vector<half, K_HEAD_SIZE> data, uchar* out, uint out_offset, uint token_pos) {
+        uint scale_offset = out_offset + K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + token_pos * sizeof(half);
+        half max_val = cm_reduced_max<half>(data);
+        half min_val = cm_reduced_min<half>(data);
+        half scale_val = half(0.0);
+        half zp_val = half(0.0);
+        if (max_val == min_val) {
+            scale_val = half(0.0);
+            zp_val = max_val;
+        } else {
+            scale_val = 255.0 / (max_val - min_val);
+            zp_val = (0.0 - min_val) * scale_val;
+        }
+        vector<half, K_HEAD_SIZE> quant_data = cm_mul<half>(data, scale_val) + zp_val;
+        vector<uchar, K_HEAD_SIZE> data_u8 = cm_rnde<uchar, K_HEAD_SIZE>(quant_data);
+        cm_ptr_store<uint32_t, K_HEAD_SIZE / 4>((uint32_t*)(out + out_offset + token_pos * K_HEAD_SIZE), 0, data_u8.format<uint32_t>());
+        half* out_scale_zp = (half*)(out + scale_offset);
+        out_scale_zp[0] = (max_val - min_val) / 255.0;
+        out_scale_zp[PAGED_ATTENTION_BLOCK_SIZE] = zp_val;
+    };
+#endif
+
     {
         uint block_k_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_K_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
         uint key_out_offset = block_k_base_offset + token_start_pos * K_HEAD_SIZE;
         uint key_in_offset = token_idx * key_pitch + head_idx * K_HEAD_SIZE + key_offset;
 
         vector<half, K_HEAD_SIZE> key_data;
         key_data.format<int>() = cm_ptr_load<int, K_HEAD_SIZE / 2>((int*)key, key_in_offset * (int)sizeof(half));
+
+#if KV_CACHE_COMPRESSION_PER_TOKEN
+        quantize_and_store(key_data, (uchar*)key_cache, block_k_base_offset, token_start_pos);
+#else
         cm_ptr_store<int, K_HEAD_SIZE / 2>((int*)key_cache, key_out_offset * (int)sizeof(half), key_data.format<int>());
+#endif
     }
     {
         uint block_v_base_offset = (block_indices[block_offset] * KV_HEADS_NUM + head_idx) * ADJUSTED_V_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
@@ -106,6 +140,10 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
 
         vector<half, V_HEAD_SIZE> value_data;
         value_data.format<int>() = cm_ptr_load<int, V_HEAD_SIZE / 2>((int*)value, value_in_offset * (int)sizeof(half));
+#if KV_CACHE_COMPRESSION_PER_TOKEN
+        quantize_and_store(value_data, (uchar*)value_cache, block_v_base_offset, token_start_pos);
+#else
         cm_ptr_store<int, V_HEAD_SIZE / 2>((int*)value_cache, value_out_offset * (int)sizeof(half), value_data.format<int>());
+#endif
     }
 }
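
Note (illustration, not part of the commit): quantize_and_store above performs per-token asymmetric u8 quantization. Each token's head vector is mapped onto [0, 255] with scale_val = 255 / (max - min) and zp_val = (0 - min) * scale_val, and the values needed later for dequantization ((max - min) / 255 and zp_val, both stored as half) are written after the block's u8 data. A minimal stand-alone C++ sketch of the same math, with illustrative names such as quantize_token:

// Stand-alone model of the per-token u8 quantization used by the kernel above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct QuantizedToken {
    std::vector<uint8_t> data;  // head_size u8 codes for one token
    float scale;                // stored as (max - min) / 255, used when dequantizing
    float zero_point;           // u8 code that corresponds to the real value 0
};

QuantizedToken quantize_token(const std::vector<float>& token) {
    const float max_val = *std::max_element(token.begin(), token.end());
    const float min_val = *std::min_element(token.begin(), token.end());
    float scale_val = 0.0f;
    float zp_val = 0.0f;
    if (max_val != min_val) {
        scale_val = 255.0f / (max_val - min_val);  // maps [min, max] onto [0, 255]
        zp_val = (0.0f - min_val) * scale_val;     // u8 code of real 0
    } else {
        zp_val = max_val;                          // degenerate case: constant vector
    }
    QuantizedToken q;
    q.data.reserve(token.size());
    for (float v : token) {
        // the kernel rounds to nearest even via cm_rnde; lround is close enough here
        q.data.push_back(static_cast<uint8_t>(std::lround(v * scale_val + zp_val)));
    }
    q.scale = (max_val - min_val) / 255.0f;
    q.zero_point = zp_val;
    return q;
}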

src/plugins/intel_gpu/src/graph/impls/cm/pa_multi_token.cm

Lines changed: 66 additions & 35 deletions
@@ -26,15 +26,22 @@ namespace KERNEL_NAME {
 
 //extern "C" _GENX_MAIN_ void pa_multi_token(
 extern "C" _GENX_MAIN_ void KERNEL_NAME(
+    //query [q_len, num_heads, S]
     half* query [[type("svmptr_t")]],
-    half* key [[type("svmptr_t")]],
-    half* value [[type("svmptr_t")]],
+#if CMPA_KVCACHE_U8
+    int8_t* k_cache [[type("svmptr_t")]],
+    int8_t* v_cache [[type("svmptr_t")]],
+#else
+    half* k_cache [[type("svmptr_t")]],
+    half* v_cache [[type("svmptr_t")]],
+#endif
     int32_t* past_lens [[type("svmptr_t")]],
     int32_t* block_indices [[type("svmptr_t")]],
    int32_t* block_indices_begins [[type("svmptr_t")]],
    int32_t* subsequence_begins [[type("svmptr_t")]],
 #if SPARSE_BLOCK_SIZE > 1
     bool* sparse_block_mask [[type("svmptr_t")]],
+    bool* sparse_block_mask_wg [[type("svmptr_t")]],
 #endif
     half* output [[type("svmptr_t")]],
     int q_len) {
@@ -44,16 +51,26 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
     constexpr int num_kv_heads = CMFLA_NUM_KV_HEADS;
     constexpr int pa_block_sz = CMPA_BLOCK_SZ;
     //# query [q_len, num_heads, S]
-    //# key [kv_len, num_heads, S]
-    //# value [kv_len, num_heads, S]
-    //# sparse_block_mask [num_heads, q_blocks, kv_blocks]
+    //# k_cache [kv_len, num_heads, S]
+    //# v_cache [kv_len, num_heads, S]
+#if CMPA_KVCACHE_U8
+    constexpr uint K_SLM_SIZE = (4*kv_step * head_size * sizeof(half));
+    constexpr uint V_SLM_SIZE = (4*kv_step * head_size * sizeof(half));
+    constexpr uint Q_SLM_SIZE = 0; //(q_step * head_size * sizeof(half)) * local_size;
+
+    cm_slm_init(K_SLM_SIZE + V_SLM_SIZE + Q_SLM_SIZE);
 
+    auto slm_K = cm_slm_alloc(K_SLM_SIZE);
+    auto slm_V = cm_slm_alloc(V_SLM_SIZE);
+
+#endif
     auto batch = cm_group_id(0);
     auto h = cm_group_id(1);
     auto hkv = h / (num_heads/num_kv_heads);
     auto wg_id = cm_group_id(2); // each work-group handles a sequence
     auto wg_local_id = cm_local_id(2);
     int local_size = cm_local_size(2);
+
     int q_start_sg, kv_start, kv_seq_len, q_len_sg;
 
     // multiple work-groups are required to split a sequence,
@@ -91,57 +108,71 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
     ---------------------------------
     each grid can be [q_len_per_trunk, q_len_per_trunk].
     For each trunk, [q_len_per_trunk, past_q_lens] must be calculated. Such as: `20`,`21`. but for the 22,
-    causal mask optimization can be applied. different wgs would has different kv stop.
+    causal mask optimization can be applied. Different wgs would have different kv stops.
     //todo:kv_stop is wg level, should we change to sg level?
+    sg-level would cause sgs in one wg to diverge, so leave it for now; also, one wg sharing the same kv_stop makes kv copying/loading into SLM/cache easier.
     */
     kv_stop = (wg_id + 1) * wg_seq_len + past_q_lens;
     if (kv_stop > kv_seq_len) kv_stop = kv_seq_len;
     }
-
-    // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop);
-    // qkv fused
-    // constexpr uint num_total_heads = num_heads + num_kv_heads * 2;
-    // uint q_offset = (q_start*num_total_heads + h)*head_size;
-    // uint k_offset = (kv_start*num_total_heads + num_heads + hkv)*head_size;
-    // uint v_offset = (kv_start*num_total_heads + num_heads + num_kv_heads + hkv)*head_size;
+    // printf("###########wg:%d.%d q: %d, +%d kv: %d, +%d, kvstop:%d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop);
 
     //Q/O[B, L, H, S]
     uint q_offset = (q_start_sg*num_heads + h)*head_size;
-    uint o_offset = (q_start_sg*num_heads + h)*head_size;
-
-    //K/V[block_num, kv_heads, block_sz, head_sz]
-    uint k_offset = hkv*head_size*pa_block_sz;
-    uint v_offset = hkv*head_size*pa_block_sz;
 
 #if SPARSE_BLOCK_SIZE > 1
     //# sparse_block_mask [num_heads, q_blocks, kv_blocks]
     auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE;
     int q_blocks = (q_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE;
     int kv_blocks = (kv_seq_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE;
+    //[self.num_heads, q_block_num, kv_block_num]
     bool* block_mask_base = sparse_block_mask + (h * q_blocks + q_start_block)*kv_blocks;
+    //[self.num_heads, wg_count_along_query, kv_block_num]
+    bool* wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id)*kv_blocks;
     // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, q_blocks, kv_blocks, sparse_block_mask, block_mask_base);
 #endif
 
-#if USE_LSC == 1
-    pa_kernel_lsc_prefetch<is_causal, num_heads, num_kv_heads, head_size, 0, 16>(
-        wg_local_id,
-        q_start_sg, //q_start for SG,
-        kv_stop,
-        q_len_sg, //q_step,
-        kv_seq_len, //kv_len, not used for now
-        reinterpret_cast<svmptr_t>(query + q_offset),
-        reinterpret_cast<svmptr_t>(key + k_offset),
-        reinterpret_cast<svmptr_t>(value + v_offset),
+#if CMPA_KVCACHE_U8
+    uint kv_offset = hkv*(head_size+4)*pa_block_sz;
+    pa_lsc_u8<is_causal, num_heads, num_kv_heads, head_size, 0>(
+        slm_K,
+        slm_V,
+        wg_local_id,
+        local_size,
+        q_start_sg, //q_start for SG,
+        kv_stop,
+        q_len_sg, //q_step,
+        kv_seq_len, //kv_len,
+        reinterpret_cast<svmptr_t>(query + q_offset),
+        reinterpret_cast<svmptr_t>(k_cache + kv_offset),
+        reinterpret_cast<svmptr_t>(v_cache + kv_offset),
 #if SPARSE_BLOCK_SIZE > 1
-        reinterpret_cast<svmptr_t>(block_mask_base),
+        reinterpret_cast<svmptr_t>(block_mask_base),
+        reinterpret_cast<svmptr_t>(wg_block_mask_base),
+
 #endif
-        reinterpret_cast<svmptr_t>(output + o_offset),
-        past_q_lens,
-        block_indices);
+        reinterpret_cast<svmptr_t>(output + q_offset),
+        past_q_lens,
+        block_indices);
 #else
-    static_assert(0);
+    uint kv_offset = hkv*head_size*pa_block_sz;
+    pa_kernel_lsc_prefetch_f16<is_causal, num_heads, num_kv_heads, head_size, 0, 16>(
+        wg_local_id,
+        q_start_sg, //q_start for SG,
+        kv_stop,
+        q_len_sg, //q_step,
+        kv_seq_len, //kv_len,
+        reinterpret_cast<svmptr_t>(query + q_offset),
+        reinterpret_cast<svmptr_t>(k_cache + kv_offset),
+        reinterpret_cast<svmptr_t>(v_cache + kv_offset),
+#if SPARSE_BLOCK_SIZE > 1
+        reinterpret_cast<svmptr_t>(block_mask_base),
+        reinterpret_cast<svmptr_t>(wg_block_mask_base),
 
+#endif
+        reinterpret_cast<svmptr_t>(output + q_offset),
+        past_q_lens,
+        block_indices);
 #endif
 }
-
-} // NAMESPACE
+} // namespace KERNEL_NAME
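
Editor's note: on the u8 path the kernel forwards the compressed k_cache/v_cache blocks (plus the SLM staging buffers) to pa_lsc_u8, which presumably lives in cm_sdpa_common.hpp (its large diff is not rendered above) and dequantizes tokens while loading them. Assuming the layout written by quantize_and_store in pa_kv_cache_update_ref.cm, the per-element recovery is simply (u8 - zp) * scale; a hypothetical C++ sketch:

// Hypothetical sketch, not from the commit: undoing the per-token quantization
// q = round(x * 255 / (max - min) + zp), given the stored scale = (max - min) / 255 and zp.
#include <cstdint>
#include <vector>

std::vector<float> dequantize_token(const uint8_t* data, int head_size, float scale, float zp) {
    std::vector<float> out(static_cast<size_t>(head_size));
    for (int i = 0; i < head_size; ++i) {
        out[i] = (static_cast<float>(data[i]) - zp) * scale;  // recovers the original value up to rounding
    }
    return out;
}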

src/plugins/intel_gpu/src/graph/impls/cm/pa_single_token.cm

Lines changed: 2 additions & 5 deletions
@@ -45,9 +45,6 @@
 
 #define KV_PARTITION_STEP_NUM (KV_PARTITION_SIZE / KV_STEP)
 
-#define KV_SCALE_ZP_SIZE 0 // 4: scale/zp size
-
 #define DEBUG_ENABLE 0
 #if DEBUG_ENABLE
 template<typename T, int M, int N>
@@ -103,7 +100,7 @@ void show(vector<T, N> vec) {
 
 //prepack [K, N] to [K/2, N, 2] layout.
 template <typename T1, typename T2, int K, int N>
-inline void prepackAsVNNIWidth2(matrix_ref<T1, K, N> input, matrix_ref<T2, K/2, N*2> out) {
+inline void prepack_to_VNNI_W2(matrix_ref<T1, K, N> input, matrix_ref<T2, K/2, N*2> out) {
 #pragma unroll
     for (int r = 0; r < K/2; r++) {
         out.row(r).select<N, 2>(0) = input.row(r*2);
@@ -498,7 +495,7 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
             VmatNormal[r] = 0;
         }
     }
-    prepackAsVNNIWidth2(VmatNormal, Vmat.format<half, REG_K/2, REG_N*2>());
+    prepack_to_VNNI_W2(VmatNormal, Vmat.format<half, REG_K/2, REG_N*2>());
 #else
     cm_load<lsc::VNNI>(Vmat[0].format<half>(), b2dV.set_block_y(kv_pos));
 #endif
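
Editor's note: the renamed helper prepack_to_VNNI_W2 only changes its name; the layout it produces is the same width-2 VNNI prepack, where [K, N] is repacked to [K/2, N, 2] by interleaving each pair of rows. A scalar C++ model of that interleaving (illustrative, not part of the commit):

// Scalar model of the [K, N] -> [K/2, N, 2] width-2 VNNI prepack:
// out[r][2*c] = in[2*r][c] and out[r][2*c+1] = in[2*r+1][c].
#include <array>
#include <cstddef>

template <typename T, size_t K, size_t N>
std::array<std::array<T, 2 * N>, K / 2> prepack_vnni_w2(const std::array<std::array<T, N>, K>& in) {
    static_assert(K % 2 == 0, "K must be even");
    std::array<std::array<T, 2 * N>, K / 2> out{};
    for (size_t r = 0; r < K / 2; ++r) {
        for (size_t c = 0; c < N; ++c) {
            out[r][2 * c]     = in[2 * r][c];      // even source row goes to even slots
            out[r][2 * c + 1] = in[2 * r + 1][c];  // odd source row goes to odd slots
        }
    }
    return out;
}

For example, a 4x8 matrix becomes a 2x16 matrix whose row r holds rows 2r and 2r+1 interleaved element by element.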

src/plugins/intel_gpu/src/graph/impls/cm/paged_attention.hpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ struct PagedAttentionImplementationManager : public ImplementationManager {
 };
 static constexpr std::array supported_kv_types = {
     ov::element::f16,
-    // ov::element::i8,
+    ov::element::i8,
 };
 
 auto& engine = node.get_program().get_engine();

src/plugins/intel_gpu/src/graph/impls/cm/paged_attention_gen.cpp

Lines changed: 40 additions & 5 deletions
@@ -66,10 +66,25 @@ inline size_t get_kv_len(const RuntimeParams& params, const PagedAttentionStage&
     return 0; // Fallback case, should not be reached
 }
 
+inline size_t get_input_kv_len(const RuntimeParams& params) {
+    auto key_shape = params.input_layouts[PagedAttentionInputIdx::KEY].get_shape();
+    const size_t kv_len = key_shape[key_shape.size() - 2];
+    return kv_len;
+}
+
 inline size_t get_aligned_kv_len(const size_t kv_len) {
     return (kv_len + PA_KV_CACHE_BLOCK_SIZE - 1) / PA_KV_CACHE_BLOCK_SIZE * PA_KV_CACHE_BLOCK_SIZE;
 }
 
+inline bool get_kv_compressed(const RuntimeParams& params) {
+    auto key_cache_layout = params.input_layouts[PagedAttentionInputIdx::KEY_CACHE];
+    if (data_type_traits::is_i8_u8(key_cache_layout.data_type)) {
+        return true;
+    } else {
+        return false;
+    }
+}
+
 int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size = 16) {
     // Since at prefill stage Q, K, V inputs may contain multiple sequences with arbitrary
     // target sequence lengths each (shape is [sequences_num * target_seq_len, num_heads * head_size]),
@@ -268,10 +283,18 @@ JitConstants PagedAttentionGeneratorKVCacheUpdate::get_jit_constants(const kerne
     jit.make("KV_HEADS_NUM", desc->kv_heads_num);
     jit.make("K_HEAD_SIZE", desc->k_head_size);
     jit.make("V_HEAD_SIZE", desc->v_head_size);
-    jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size);
-    jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size);
     jit.make("PAGED_ATTENTION_BLOCK_SIZE", PA_KV_CACHE_BLOCK_SIZE);
 
+    if (get_kv_compressed(params)) {
+        jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 1);
+        jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size + 4);
+        jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size + 4);
+    } else {
+        jit.make("KV_CACHE_COMPRESSION_PER_TOKEN", 0);
+        jit.make("ADJUSTED_K_HEAD_SIZE", desc->k_head_size);
+        jit.make("ADJUSTED_V_HEAD_SIZE", desc->v_head_size);
+    }
+
     return jit;
 }
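
Editor's note: the +4 added to ADJUSTED_K_HEAD_SIZE / ADJUSTED_V_HEAD_SIZE reserves, per token, 2 bytes for the half scale and 2 bytes for the half zero-point that pa_kv_cache_update_ref.cm stores after the block's u8 data. A small arithmetic sketch of the per-(block, head) footprint, with illustrative values (the real head size comes from the descriptor and the block size from PA_KV_CACHE_BLOCK_SIZE):

// Illustrative arithmetic only; head_size and block_size below are assumptions, not from the commit.
#include <cstdio>

int main() {
    constexpr int head_size = 128;
    constexpr int block_size = 16;
    constexpr int f16_bytes = head_size * block_size * 2;   // half-precision cache slice: 4096 bytes
    constexpr int u8_bytes = (head_size + 4) * block_size;  // u8 data + per-token scale/zp: 2112 bytes
    std::printf("f16 slice: %d bytes, u8 slice: %d bytes\n", f16_bytes, u8_bytes);
    return 0;
}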

@@ -302,7 +325,8 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func()
     const auto desc = params.typed_desc<paged_attention>();
     // auto rtp = static_cast<PagedAttentionRuntimeParams*>(rt_params);
 
-    const size_t kv_len = get_max_context_len(params);
+    // const size_t kv_len = get_max_context_len(params);
+    const size_t kv_len = get_input_kv_len(params);
     const size_t kv_heads_num = desc->kv_heads_num;
     const size_t wg_count = (kv_len + WG_SIZE - 1) / WG_SIZE;

@@ -372,7 +396,8 @@ DispatchDataFunc PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func()
     if (DEBUG_ENABLED) { // Debug
         std::cout << "PagedAttentionGeneratorKVCacheUpdate::get_dispatch_data_func: "
                   << "kv_len: " << kv_len << ", key_pitch: " << key_pitch << ", key_offset: " << key_offset << ", value_pitch: " << value_pitch
-                  << ", value_offset: " << value_offset << ", "<< std::endl;
+                  << ", value_offset: " << value_offset << ", gws: [" << wgs.global[0] << ", " << wgs.global[1] << ", " << wgs.global[2] << "]"
+                  << ", lws: [" << wgs.local[0] << ", " << wgs.local[1] << ", " << wgs.local[2] << "]" << std::endl;
     }
 
     // TODO: support multiple sequences
@@ -429,6 +454,12 @@ JitConstants PagedAttentionGeneratorMultiToken::get_jit_constants(const kernel_i
     jit.make("CMPA_BLOCK_SZ", PA_KV_CACHE_BLOCK_SIZE);
     jit.make("SPARSE_BLOCK_SIZE", xattn_block_size);
     jit.make("Q_STEP", get_q_step(xe_arch, true));
+
+    if (get_kv_compressed(params)) {
+        jit.make("CMPA_KVCACHE_U8", 1);
+    } else {
+        jit.make("CMPA_KVCACHE_U8", 0);
+    }
     // for (auto& it : jit) {
     //     std::cout << "\tjit[" << it.name << "] = " << it.value << std::endl;
     // }
@@ -509,7 +540,11 @@ JitConstants PagedAttentionGeneratorSingleToken::get_jit_constants(const kernel_
     jit.make("KV_HEADS_NUM", desc->kv_heads_num);
     jit.make("Q_STEP", get_q_step(xe_arch, true));
 
-    jit.make("KV_CACHE_COMPRESSION", 0);
+    if (get_kv_compressed(params)) {
+        jit.make("KV_CACHE_COMPRESSION", 1);
+    } else {
+        jit.make("KV_CACHE_COMPRESSION", 0);
+    }
 
     return jit;
 }

src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.hpp

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ struct PagedAttentionOpt : public ImplementationManager {
 };
 static constexpr std::array supported_kv_types = {
 #if ENABLE_PA_CM_PATH
-    ov::element::i8,
+    ov::element::f32,
 #else
     ov::element::f32,
     ov::element::f16,
