@@ -26,15 +26,22 @@ namespace KERNEL_NAME {
2626
2727//extern "C" _GENX_MAIN_ void pa_multi_token(
2828extern "C" _GENX_MAIN_ void KERNEL_NAME(
29+ //query [q_len, num_heads, S]
2930 half* query [[type("svmptr_t")]],
30- half* key [[type("svmptr_t")]],
31- half* value [[type("svmptr_t")]],
31+ #if CMPA_KVCACHE_U8
32+ int8_t* k_cache [[type("svmptr_t")]],
33+ int8_t* v_cache [[type("svmptr_t")]],
34+ #else
35+ half* k_cache [[type("svmptr_t")]],
36+ half* v_cache [[type("svmptr_t")]],
37+ #endif
3238 int32_t* past_lens [[type("svmptr_t")]],
3339 int32_t* block_indices [[type("svmptr_t")]],
3440 int32_t* block_indices_begins [[type("svmptr_t")]],
3541 int32_t* subsequence_begins [[type("svmptr_t")]],
3642#if SPARSE_BLOCK_SIZE > 1
3743 bool* sparse_block_mask [[type("svmptr_t")]],
44+ bool* sparse_block_mask_wg [[type("svmptr_t")]],
3845#endif
3946 half* output [[type("svmptr_t")]],
4047 int q_len) {
@@ -44,16 +51,26 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
4451 constexpr int num_kv_heads = CMFLA_NUM_KV_HEADS;
4552 constexpr int pa_block_sz = CMPA_BLOCK_SZ;
4653 //# query [q_len, num_heads, S]
47- //# key [kv_len, num_heads, S]
48- //# value [kv_len, num_heads, S]
49- //# sparse_block_mask [num_heads, q_blocks, kv_blocks]
54+ //# k_cache [kv_len, num_heads, S]
55+ //# v_cache [kv_len, num_heads, S]
56+ #if CMPA_KVCACHE_U8
57+ constexpr uint K_SLM_SIZE = (4*kv_step * head_size * sizeof(half));
58+ constexpr uint V_SLM_SIZE = (4*kv_step * head_size * sizeof(half));
59+ constexpr uint Q_SLM_SIZE = 0;//(q_step * head_size * sizeof(half)) * local_size;
60+
61+ cm_slm_init(K_SLM_SIZE + V_SLM_SIZE + Q_SLM_SIZE);
5062
63+ auto slm_K = cm_slm_alloc(K_SLM_SIZE);
64+ auto slm_V = cm_slm_alloc(V_SLM_SIZE);
65+
66+ #endif
5167 auto batch = cm_group_id(0);
5268 auto h = cm_group_id(1);
5369 auto hkv = h / (num_heads/num_kv_heads);
5470 auto wg_id = cm_group_id(2); // each work-group handles a sequence
5571 auto wg_local_id = cm_local_id(2);
5672 int local_size = cm_local_size(2);
73+
5774 int q_start_sg, kv_start, kv_seq_len, q_len_sg;
5875
5976 // multiple work-groups are required to split a sequence,
@@ -91,57 +108,71 @@ extern "C" _GENX_MAIN_ void KERNEL_NAME(
91108 ---------------------------------
92109 each grid can be [q_len_per_trunk, q_len_per_trunk].
93110 For each trunk, [q_len_per_trunk, past_q_lens] must be calculated. Such as: `20`,`21`. but for the 22,
94- causal mask optimization can be applied. different wgs would has different kv stop.
111+ causal mask optimization can be applied. different wgs would have different kv stops.
95112 //todo:kv_stop is wg level, should we change to sg level?
113+ sg-level would cause sgs in one wg to diverge, so leave it for now. Also, one wg having the same kv_stop makes kv copying/loading into SLM/cache easier.
96114 */
97115 kv_stop = (wg_id + 1) * wg_seq_len + past_q_lens;
98116 if (kv_stop > kv_seq_len) kv_stop = kv_seq_len;
99117 }
100-
101- // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop);
102- // qkv fused
103- // constexpr uint num_total_heads = num_heads + num_kv_heads * 2;
104- // uint q_offset = (q_start*num_total_heads + h)*head_size;
105- // uint k_offset = (kv_start*num_total_heads + num_heads + hkv)*head_size;
106- // uint v_offset = (kv_start*num_total_heads + num_heads + num_kv_heads + hkv)*head_size;
118+ // printf("###########wg:%d.%d q: %d, +%d kv: %d, +%d, kvstop:%d\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop);
107119
108120 //Q/O[B, L, H, S]
109121 uint q_offset = (q_start_sg*num_heads + h)*head_size;
110- uint o_offset = (q_start_sg*num_heads + h)*head_size;
111-
112- //K/V[block_num, kv_heads, block_sz, head_sz]
113- uint k_offset = hkv*head_size*pa_block_sz;
114- uint v_offset = hkv*head_size*pa_block_sz;
115122
116123#if SPARSE_BLOCK_SIZE > 1
117124 //# sparse_block_mask [num_heads, q_blocks, kv_blocks]
118125 auto q_start_block = q_start_sg/ SPARSE_BLOCK_SIZE;
119126 int q_blocks = (q_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE;
120127 int kv_blocks = (kv_seq_len + SPARSE_BLOCK_SIZE - 1) / SPARSE_BLOCK_SIZE;
128+ //[self.num_heads, q_block_num, kv_block_num]
121129 bool* block_mask_base = sparse_block_mask + (h * q_blocks + q_start_block)*kv_blocks;
131+ //[self.num_heads, wg_count_along_query, kv_block_num]
131+ bool* wg_block_mask_base = sparse_block_mask_wg + (h * cm_group_count(2) + wg_id)*kv_blocks;
122132 // printf("wg:%d.%d q: %d, +%d kv: %d, +%d, %d, x-attn: %d, %dx%d, %p, %p\n", wg_id, wg_local_id, q_start_sg, q_len_sg, kv_start, kv_seq_len, kv_stop, q_start_block, q_blocks, kv_blocks, sparse_block_mask, block_mask_base);
123133#endif
124134
125- #if USE_LSC == 1
126- pa_kernel_lsc_prefetch<is_causal, num_heads, num_kv_heads, head_size, 0, 16>(
127- wg_local_id,
128- q_start_sg, //q_start for SG,
129- kv_stop,
130- q_len_sg, //q_step,
131- kv_seq_len, //kv_len, not used for now
132- reinterpret_cast<svmptr_t>(query + q_offset),
133- reinterpret_cast<svmptr_t>(key + k_offset),
134- reinterpret_cast<svmptr_t>(value + v_offset),
135+ #if CMPA_KVCACHE_U8
136+ uint kv_offset = hkv*(head_size+4)*pa_block_sz;
137+ pa_lsc_u8<is_causal, num_heads, num_kv_heads, head_size, 0>(
138+ slm_K,
139+ slm_V,
140+ wg_local_id,
141+ local_size,
142+ q_start_sg, //q_start for SG,
143+ kv_stop,
144+ q_len_sg, //q_step,
145+ kv_seq_len, //kv_len,
146+ reinterpret_cast<svmptr_t>(query + q_offset),
147+ reinterpret_cast<svmptr_t>(k_cache + kv_offset),
148+ reinterpret_cast<svmptr_t>(v_cache + kv_offset),
135149#if SPARSE_BLOCK_SIZE > 1
136- reinterpret_cast<svmptr_t>(block_mask_base),
150+ reinterpret_cast<svmptr_t>(block_mask_base),
151+ reinterpret_cast<svmptr_t>(wg_block_mask_base),
152+
137153#endif
138- reinterpret_cast<svmptr_t>(output + o_offset ),
139- past_q_lens,
140- block_indices);
154+ reinterpret_cast<svmptr_t>(output + q_offset ),
155+ past_q_lens,
156+ block_indices);
141157#else
142- static_assert(0);
158+ uint kv_offset = hkv*head_size*pa_block_sz;
159+ pa_kernel_lsc_prefetch_f16<is_causal, num_heads, num_kv_heads, head_size, 0, 16>(
160+ wg_local_id,
161+ q_start_sg, //q_start for SG,
162+ kv_stop,
163+ q_len_sg, //q_step,
164+ kv_seq_len, //kv_len,
165+ reinterpret_cast<svmptr_t>(query + q_offset),
166+ reinterpret_cast<svmptr_t>(k_cache + kv_offset),
167+ reinterpret_cast<svmptr_t>(v_cache + kv_offset),
168+ #if SPARSE_BLOCK_SIZE > 1
169+ reinterpret_cast<svmptr_t>(block_mask_base),
170+ reinterpret_cast<svmptr_t>(wg_block_mask_base),
143171
172+ #endif
173+ reinterpret_cast<svmptr_t>(output + q_offset),
174+ past_q_lens,
175+ block_indices);
144176#endif
145177}
146-
147- } // NAMESPACE
178+ } // namespace KERNEL_NAME
0 commit comments