Commit f9f9f74

Merge pull request #315 from kvcache-ai/Atream-add-adapted
Atream add adapted
2 parents: 1548c99 + 9239928

File tree

2 files changed (+9, -3 lines)

ktransformers/operators/attention.py

Lines changed: 2 additions & 2 deletions
@@ -262,7 +262,7 @@ def forward_linux(
         """

         # flash attn doesn't support head_dim bigger than 256
-        # use vLLM triton attention kernel for MQA
+        # use triton attention kernel adapted from vLLM and SGLang for MQA
         decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                                      page_table,
                                      position_ids.squeeze(0).to(torch.int32), attn_logits,
@@ -551,4 +551,4 @@ def forward(
         if not output_attentions:
             attn_weights = None

-        return attn_output, attn_weights, past_key_value
+        return attn_output, attn_weights, past_key_value
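
For context, the comment being reworded above marks why this code path exists: FlashAttention caps head_dim at 256, while DeepSeek-style MLA decoding attends over the concatenation of the compressed KV and the rotary part, which is wider than that. The snippet below is only an illustration of that size check; the kv_lora_rank and qk_rope_head_dim values are assumed DeepSeek-V2/V3 defaults, not taken from this diff.

# Illustrative only: why the MLA decode path cannot use flash attention.
# kv_lora_rank and qk_rope_head_dim are assumed DeepSeek-V2/V3 values.
FLASH_ATTN_MAX_HEAD_DIM = 256   # flash attn doesn't support head_dim bigger than 256

kv_lora_rank = 512              # assumed width of the compressed KV (compressed_kv)
qk_rope_head_dim = 64           # assumed width of the rotary part (k_pe)
head_dim = kv_lora_rank + qk_rope_head_dim   # 576 for compressed_kv_with_k_pe

if head_dim > FLASH_ATTN_MAX_HEAD_DIM:
    # fall back to the Triton MQA kernel adapted from vLLM and SGLang
    use_triton_decode_kernel = True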

ktransformers/operators/triton_attention.py

Lines changed: 7 additions & 1 deletion
@@ -1,3 +1,9 @@
+# Adapted from
+# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+# which was originally adapted from
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
 import triton
 import triton.language as tl

@@ -376,4 +382,4 @@ def decode_attention_fwd_grouped(
     )

     _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
-                                num_kv_splits)
+                                num_kv_splits)
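
The decode_attention_fwd_grouped entry point credited above follows the two-stage "flash decoding" scheme of the lightllm/SGLang kernels it is adapted from: stage 1 computes partial attention over num_kv_splits chunks of the KV cache, and stage 2 (_decode_softmax_reducev_fwd) merges the partial results with a log-sum-exp reduction. The plain-PyTorch sketch below illustrates only that reduction scheme for a single decode token; the tensor names, shapes, and the absence of paging and grouped heads are simplifications, not the kernel's actual interface.

import torch

def decode_attention_reference(q, k, v, num_kv_splits=4):
    # q: [num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim]
    num_heads, head_dim = q.shape
    scale = head_dim ** -0.5

    partial_out, partial_lse = [], []
    # Stage 1: attend within each KV split independently (what the grouped
    # Triton kernel parallelizes across num_kv_splits).
    for k_chunk, v_chunk in zip(k.chunk(num_kv_splits, dim=0),
                                v.chunk(num_kv_splits, dim=0)):
        scores = torch.einsum("hd,shd->hs", q, k_chunk) * scale   # [heads, chunk]
        partial_lse.append(torch.logsumexp(scores, dim=-1))       # [heads]
        probs = torch.softmax(scores, dim=-1)
        partial_out.append(torch.einsum("hs,shd->hd", probs, v_chunk))

    # Stage 2: combine the splits with a numerically stable log-sum-exp
    # reduction, analogous to what _decode_softmax_reducev_fwd does.
    lse = torch.stack(partial_lse)                 # [splits, heads]
    out = torch.stack(partial_out)                 # [splits, heads, head_dim]
    split_weight = torch.softmax(lse, dim=0)       # each split's share of the total softmax mass
    return (split_weight.unsqueeze(-1) * out).sum(dim=0)

# Sanity check against full-sequence softmax attention:
q = torch.randn(16, 64)
k = torch.randn(128, 16, 64)
v = torch.randn(128, 16, 64)
full = torch.einsum("hs,shd->hd",
                    torch.softmax(torch.einsum("hd,shd->hs", q, k) * 64 ** -0.5, dim=-1), v)
assert torch.allclose(decode_attention_reference(q, k, v), full, atol=1e-5)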
