diff --git a/README.md b/README.md
index d204fddff6..498b814d19 100644
--- a/README.md
+++ b/README.md
@@ -142,6 +142,8 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
DeepSeek-V2.5 (236B)
+ DeepSeek-V3 (685B)
+ DeepSeek-V3.2 (685B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
StarCoder2 (3B - 15B)
diff --git a/README_ja.md b/README_ja.md
index 75d05390ad..b693918c3a 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -129,6 +129,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
DeepSeek-V2.5 (236B)
+ DeepSeek-V3 (685B)
+ DeepSeek-V3.2 (685B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
StarCoder2 (3B - 15B)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index f6f10a5b42..b1bc99f46e 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -143,6 +143,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
DeepSeek-MoE (16B)
DeepSeek-V2 (16B, 236B)
DeepSeek-V2.5 (236B)
+ DeepSeek-V3 (685B)
+ DeepSeek-V3.2 (685B)
Mixtral (8x7B, 8x22B)
Gemma (2B - 7B)
StarCoder2 (3B - 15B)
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index aa28854d8a..966353b99e 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -90,6 +90,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V3 | 685B | LLM | Yes | No | No | No | No |
+| DeepSeek-V3.2 | 685B | LLM | Yes | No | No | No | No |
| DeepSeek-VL2 | 3B - 27B | MLLM | Yes | No | No | No | No |
| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index 8e9e3fef20..27bd84b331 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -90,6 +90,8 @@
| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V3 | 685B | LLM | Yes | No | No | No | No |
+| DeepSeek-V3.2 | 685B | LLM | Yes | No | No | No | No |
| DeepSeek-VL2 | 3B - 27B | MLLM | Yes | No | No | No | No |
| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py
index 0c842e8871..63941aebfb 100644
--- a/lmdeploy/pytorch/backends/attention.py
+++ b/lmdeploy/pytorch/backends/attention.py
@@ -15,6 +15,8 @@ class AttentionMetadata:
q_seqlens: torch.Tensor = None
kv_seqlens: torch.Tensor = None
fill_seqlens: torch.Tensor = None
+ cu_seqlens_q: torch.Tensor = None
+ cu_seqlens_k: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0
@@ -70,6 +72,7 @@ def forward(
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
+ nsa_indices: torch.Tensor = None,
inplace: bool = False,
) -> torch.Tensor:
"""forward."""
diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py
index fb3ff9ded7..937cdf01ee 100644
--- a/lmdeploy/pytorch/backends/base.py
+++ b/lmdeploy/pytorch/backends/base.py
@@ -31,6 +31,7 @@ class OpType(Enum):
FusedMoEW8A8 = auto()
LinearBlockedF8 = auto()
FusedMoEBlockedF8 = auto()
+ NSAIndexFP8 = auto()
class OpsBackend(ABC):
diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py
index b241c384b2..69d8a47ea0 100644
--- a/lmdeploy/pytorch/backends/cuda/attention.py
+++ b/lmdeploy/pytorch/backends/cuda/attention.py
@@ -107,6 +107,7 @@ def forward(
v_scales_zeros: torch.Tensor = None,
learnable_sink: torch.Tensor = None,
inplace: bool = True,
+ **kwargs,
) -> torch.Tensor:
"""forward."""
block_offsets = attn_metadata.block_offsets
@@ -231,6 +232,69 @@ def use_fa3_warning():
return False
+def _try_dynamic_compile(func, *args, **kwargs):
+    """Try to compile func with dynamic shapes, falling back to the eager function on failure."""
+ try:
+ compiled_func = torch.compile(func, dynamic=True)
+ compiled_func(*args, **kwargs)
+ return compiled_func
+ except Exception:
+ return func
+
+
+class NSAIndicesUpdater:
+ """NSA indices updater.
+
+    Flash MLA sparse attention requires a different index format for prefill and for decoding. This class updates
+    the indices to meet those requirements.
+ """
+
+ def __init__(self):
+ self._update_decode_func = None
+ self._update_prefill_func = None
+
+ def _update_decode_impl(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor,
+ block_size: int) -> torch.Tensor:
+ """Update for decode impl."""
+ block_ids = nsa_indices // block_size
+ block_ids = block_ids.clamp_min(0)
+ block_ids = block_offsets.gather(1, block_ids)
+ block_remain = nsa_indices % block_size
+ ret = block_ids * block_size + block_remain
+ ret[nsa_indices < 0] = -1
+ return ret[:, None]
+
+ def update_decode(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor, block_size: int) -> torch.Tensor:
+ """Update for decode."""
+ if self._update_decode_func is None:
+ self._update_decode_func = _try_dynamic_compile(self._update_decode_impl, nsa_indices, block_offsets,
+ block_size)
+
+ return self._update_decode_func(nsa_indices, block_offsets, block_size)
+
+ def _update_prefill_impl(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor):
+ """Update for prefill impl."""
+ num_tokens = nsa_indices.size(0)
+ repeat_cu_seqlens_k = torch.repeat_interleave(cu_seqlens_k[:-1], q_seqlens, output_size=num_tokens)
+ neg_mask = nsa_indices < 0
+ nsa_indices = nsa_indices + repeat_cu_seqlens_k[:, None]
+ nsa_indices[neg_mask] = -1
+ return nsa_indices[:, None]
+
+ def update_prefill(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor):
+ """Update for prefill."""
+ if self._update_prefill_func is None:
+ self._update_prefill_func = _try_dynamic_compile(self._update_prefill_impl, nsa_indices, q_seqlens,
+ cu_seqlens_k)
+
+ return self._update_prefill_func(nsa_indices, q_seqlens, cu_seqlens_k)
+
+ @staticmethod
+ @functools.lru_cache(maxsize=None)
+ def build():
+ return NSAIndicesUpdater()
+
+
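To make the decode-side remapping easier to follow, here is a CPU-only sketch that mirrors `_update_decode_impl` with made-up numbers (the helper name and values are illustrative, not part of the patch): logical token positions inside a request are mapped to flattened physical cache slots via the block table, and `-1` padding entries are preserved.

```python
import torch


def remap_decode_indices(nsa_indices, block_offsets, block_size):
    # same math as NSAIndicesUpdater._update_decode_impl
    block_ids = (nsa_indices // block_size).clamp_min(0)
    block_ids = block_offsets.gather(1, block_ids)
    ret = block_ids * block_size + nsa_indices % block_size
    ret[nsa_indices < 0] = -1
    return ret[:, None]


block_size = 4
block_offsets = torch.tensor([[5, 2, 9]])     # logical block i lives in physical block offsets[i]
nsa_indices = torch.tensor([[0, 5, 11, -1]])  # selected token positions, -1 is padding
print(remap_decode_indices(nsa_indices, block_offsets, block_size))
# -> tensor([[[20,  9, 39, -1]]])
```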
class FlashMLAImpl(TritonAttentionImpl):
def __init__(
@@ -262,47 +326,264 @@ def __init__(
**kwargs,
)
- from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd
- self.flash_mla_fwd = flash_mla_fwd
+ import flash_mla
+
+ from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8
+ from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8
+ self.flash_mla_with_kvcache = flash_mla.flash_mla_with_kvcache
+ self.flash_mla_sparse_fwd = None
+ self.fill_kv_cache_blocked_fp8 = fill_kv_cache_blocked_fp8
+ self.flatten_kv_cache_mla_fp8 = flatten_kv_cache_mla_fp8
assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
- use_fa3_warning()
- def forward(
+ self.nsa_updater = NSAIndicesUpdater.build()
+
+ def _get_flash_mla_sparse_fwd(self):
+ if self.flash_mla_sparse_fwd is not None:
+ return self.flash_mla_sparse_fwd
+
+ try:
+ import flash_mla
+ self.flash_mla_sparse_fwd = flash_mla.flash_mla_sparse_fwd
+ return self.flash_mla_sparse_fwd
+ except Exception:
+            logger.exception('Cannot import flash_mla_sparse_fwd from flash_mla.')
+
+ def flash_mla_decoding(
self,
query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
k_cache: torch.Tensor,
- v_cache: torch.Tensor,
+ nsa_indices: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
- k_scales_zeros: torch.Tensor = None,
- v_scales_zeros: torch.Tensor = None,
- learnable_sink: torch.Tensor = None,
- inplace: bool = True,
- ) -> torch.Tensor:
- """forward."""
+ ):
+ """Flash mla decoding."""
+ causal = self.causal
+ kv_seqlens = attn_metadata.kv_seqlens
+ block_offsets = attn_metadata.block_offsets
+ is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
+
+ query = query.unsqueeze(1)
+ if kv_seqlens.dtype == torch.int64:
+ kv_seqlens = kv_seqlens.to(torch.int32)
+
+        # update nsa indices according to the flash-mla requirements
+ if nsa_indices is not None:
+ block_size = k_cache.size(1)
+ nsa_indices = self.nsa_updater.update_decode(nsa_indices, block_offsets, block_size)
+ causal = False
+
+ attn_output, _ = self.flash_mla_with_kvcache(query,
+ k_cache=k_cache,
+ block_table=block_offsets,
+ cache_seqlens=kv_seqlens,
+ head_dim_v=self.v_head_size,
+ softmax_scale=self.scale,
+ tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
+ num_splits=attn_metadata.num_splits,
+ causal=causal,
+ is_fp8_kvcache=is_fp8_kvcache,
+ indices=nsa_indices)
+ attn_output = attn_output.squeeze(1)
+ return attn_output
+
+ def flash_mla_prefill(self, query: torch.Tensor, flatten_k: torch.Tensor, nsa_indices: torch.Tensor,
+ attn_metadata: TritonAttentionMetadata) -> torch.Tensor:
+ """Flash mla prefill, only used in sparse attention."""
+ q_seqlens = attn_metadata.q_seqlens
+ flash_mla_sparse_fwd = self._get_flash_mla_sparse_fwd()
+
+ num_q_heads = query.size(1)
+        # flash_mla_sparse_fwd requires the number of query heads to be a multiple of 64
+ if num_q_heads % 64 != 0:
+ query = torch.nn.functional.pad(query, (0, 0, 0, 64 - num_q_heads % 64))
+
+ nsa_indices = self.nsa_updater.update_prefill(nsa_indices, q_seqlens, attn_metadata.cu_seqlens_k)
+ output = flash_mla_sparse_fwd(
+ query,
+ flatten_k,
+ nsa_indices,
+ sm_scale=self.scale,
+ )
+ attn_output = output[0]
+ attn_output = attn_output[:, :num_q_heads]
+ return attn_output
+ def flash_attn_triton(
+ self,
+ query: torch.Tensor,
+ flatten_k: torch.Tensor,
+ flatten_v: torch.Tensor,
+ attn_metadata: TritonAttentionMetadata,
+ ):
+ """Triton flash attention, used if flash-attn is not available."""
+ q_start_loc = attn_metadata.q_start_loc
+ q_seqlens = attn_metadata.q_seqlens
+ kv_start_loc = attn_metadata.kv_start_loc
+ kv_seqlens = attn_metadata.kv_seqlens
+ max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+
+ q_shape = query.shape
+ o_shape = q_shape[:-1] + (self.v_head_size, )
+ attn_output = query.new_empty(o_shape)
+ self.flash_attention_fwd(
+ query,
+ flatten_k,
+ flatten_v,
+ attn_output,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_seqlen=max_q_seqlen,
+ window_size=self.sliding_window,
+ sm_scale=self.scale,
+ logit_softcapping=self.logit_softcapping,
+ causal=self.causal,
+ )
+
+ return attn_output
+
+ def flash_attn_fa3(
+ self,
+ query: torch.Tensor,
+ flatten_k: torch.Tensor,
+ attn_metadata: TritonAttentionMetadata,
+ ):
+ """Flash attention 3, used if flash-attn 3 is available."""
+ max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+ kv_flatten_size = attn_metadata.kv_flatten_size
+ causal = self.causal
+ q_rope = query[:, :, self.v_head_size:]
+ q_nope = query[:, :, :self.v_head_size]
+ k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]
+ c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]
+ from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
+ attn_output = flash_attn_varlen_func(
+ q=q_rope,
+ k=k_rope,
+ v=c_kv,
+ qv=q_nope,
+ cu_seqlens_q=attn_metadata.cu_seqlens_q,
+ cu_seqlens_k=attn_metadata.cu_seqlens_k,
+ max_seqlen_q=max_q_seqlen,
+ max_seqlen_k=kv_flatten_size,
+ softmax_scale=self.scale,
+ causal=causal,
+ window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
+ softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
+ )
+ return attn_output
+
+ def run_flatten_kv_cache(self,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ attn_metadata: TritonAttentionMetadata,
+ out_dtype: torch.dtype,
+ is_nsa: bool,
+ k_scales_zeros: torch.Tensor = None,
+ v_scales_zeros: torch.Tensor = None):
+ """Flatten kv cache for prefill."""
+
+ kv_start_loc = attn_metadata.kv_start_loc
+ kv_seqlens = attn_metadata.kv_seqlens
+ block_offsets = attn_metadata.block_offsets
+ kv_flatten_size = attn_metadata.kv_flatten_size
+ quant_policy = attn_metadata.quant_policy
+ is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
+ BLOCK_BS = k_cache.size(1)
+
+        # pad one more block to avoid out-of-range kv access
+ out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)
+ flatten_kv_layout = 'shd' if use_fa3 or is_nsa else 'hsd'
+ if is_fp8_kvcache:
+ flatten_k = self.flatten_kv_cache_mla_fp8(
+ k_cache,
+ kv_seqlens,
+ block_offsets,
+ start_loc=kv_start_loc,
+ out_size=out_size,
+ out_dtype=out_dtype,
+ flatten_kv_layout=flatten_kv_layout,
+ )
+ flatten_v = flatten_k[..., :512]
+ else:
+ flatten_k, flatten_v = self.flatten_kv_cache(
+ k_cache,
+ v_cache,
+ kv_seqlens,
+ block_offsets,
+ start_loc=kv_start_loc,
+ out_size=kv_flatten_size if use_fa3 else out_size,
+ out_dtype=out_dtype,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros,
+ quant_policy=quant_policy,
+ flatten_kv_layout=flatten_kv_layout,
+ )
+
+ return flatten_k, flatten_v
+
+ def run_fill_kv_cache(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ attn_metadata: TritonAttentionMetadata,
+ k_scales_zeros: torch.Tensor = None,
+ v_scales_zeros: torch.Tensor = None):
+ """Fill kv cache."""
block_offsets = attn_metadata.block_offsets
q_start_loc = attn_metadata.q_start_loc
fill_q_start_loc = q_start_loc
q_seqlens = attn_metadata.q_seqlens
fill_seqlens = q_seqlens
- kv_start_loc = attn_metadata.kv_start_loc
kv_seqlens = attn_metadata.kv_seqlens
- kv_flatten_size = attn_metadata.kv_flatten_size
quant_policy = attn_metadata.quant_policy
+
+        # max query length of the current batch
if attn_metadata.is_decoding:
max_q_seqlen = 1
else:
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
+
+        # max length used when filling the kv cache (may differ when fill_seqlens is set)
fill_max_q_seqlen = max_q_seqlen
if attn_metadata.fill_seqlens is not None:
fill_seqlens = attn_metadata.fill_seqlens
fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2))
fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens
- # fill kv cache
- if key is not None and value is not None:
+ is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
+ if is_fp8_kvcache:
+ k_cache_scale = k_cache[..., 512:512 + 16].view(torch.float32)
+ k_cache_nope = k_cache[..., :512]
+ k_cache_pe = k_cache[..., 512 + 16:].view(key.dtype)
+ self.fill_kv_cache_blocked_fp8(
+ key[..., :512],
+ None,
+ k_cache_nope,
+ None,
+ k_cache_scale,
+ None,
+ cu_seqlen_q=attn_metadata.cu_seqlens_q,
+ kv_seqlens=attn_metadata.kv_seqlens,
+ max_q_seqlen=max_q_seqlen,
+ block_offsets=block_offsets,
+ group_size=128,
+ )
+ self.fill_kv_cache(
+ key[..., 512:],
+ None,
+ k_cache_pe,
+ None,
+ fill_q_start_loc,
+ fill_seqlens,
+ kv_seq_length=kv_seqlens,
+ max_q_seq_length=fill_max_q_seqlen,
+ block_offsets=block_offsets,
+ )
+ else:
self.fill_kv_cache(
key,
value,
@@ -318,77 +599,57 @@ def forward(
quant_policy=quant_policy,
)
- q_shape = query.shape
- o_shape = q_shape[:-1] + (self.v_head_size, )
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ k_cache: torch.Tensor,
+ v_cache: torch.Tensor,
+ attn_metadata: TritonAttentionMetadata,
+ k_scales_zeros: torch.Tensor = None,
+ v_scales_zeros: torch.Tensor = None,
+ nsa_indices: torch.Tensor = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """forward."""
+
+ # check nsa
+ is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn
+ is_nsa = nsa_indices is not None
+ if is_nsa:
+ assert is_fp8_kvcache
+
+ # fill kv cache
+ self.run_fill_kv_cache(
+ query,
+ key,
+ value,
+ k_cache,
+ v_cache,
+ attn_metadata,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros,
+ )
is_decoding = attn_metadata.is_decoding
if is_decoding:
- query = query.unsqueeze(1)
- if kv_seqlens.dtype == torch.int64:
- kv_seqlens = kv_seqlens.to(torch.int32)
- attn_output = self.flash_mla_fwd(query,
- k_cache=k_cache,
- block_table=block_offsets,
- cache_seqlens=kv_seqlens,
- head_dim_v=self.v_head_size,
- softmax_scale=self.scale,
- tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata,
- num_splits=attn_metadata.num_splits,
- causal=True)
+ attn_output = self.flash_mla_decoding(query, k_cache, nsa_indices, attn_metadata)
else:
- BLOCK_BS = k_cache.size(1)
- # pad one more block to avoid invalid kv visit
- out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS)
- flatten_k, flatten_v = self.flatten_kv_cache(
- k_cache,
- v_cache,
- kv_seqlens,
- block_offsets,
- start_loc=kv_start_loc,
- out_size=kv_flatten_size if use_fa3 else out_size,
- out_dtype=query.dtype,
- k_scales_zeros=k_scales_zeros,
- v_scales_zeros=v_scales_zeros,
- quant_policy=quant_policy,
- flatten_kv_layout='shd' if use_fa3 else 'hsd',
- )
- if use_fa3:
- q_rope = query[:, :, self.v_head_size:]
- q_nope = query[:, :, :self.v_head_size]
- k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:]
- c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size]
- from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func
- attn_output = flash_attn_varlen_func(
- q=q_rope,
- k=k_rope,
- v=c_kv,
- qv=q_nope,
- cu_seqlens_q=attn_metadata.cu_seqlens_q,
- cu_seqlens_k=attn_metadata.cu_seqlens_k,
- max_seqlen_q=max_q_seqlen,
- max_seqlen_k=kv_flatten_size,
- softmax_scale=self.scale,
- causal=self.causal,
- window_size=(-1, -1) if self.sliding_window is None else self.sliding_window,
- softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
- )
+ flatten_k, flatten_v = self.run_flatten_kv_cache(k_cache,
+ v_cache,
+ attn_metadata,
+ out_dtype=query.dtype,
+ is_nsa=nsa_indices is not None,
+ k_scales_zeros=k_scales_zeros,
+ v_scales_zeros=v_scales_zeros)
+ if is_nsa:
+ attn_output = self.flash_mla_prefill(query, flatten_k, nsa_indices, attn_metadata)
+ elif use_fa3:
+ attn_output = self.flash_attn_fa3(query, flatten_k, attn_metadata)
else:
- attn_output = query.new_empty(o_shape)
- self.flash_attention_fwd(
- query,
- flatten_k,
- flatten_v,
- attn_output,
- q_start_loc=q_start_loc,
- q_seqlens=q_seqlens,
- kv_start_loc=kv_start_loc,
- kv_seqlens=kv_seqlens,
- max_seqlen=max_q_seqlen,
- window_size=self.sliding_window,
- sm_scale=self.scale,
- logit_softcapping=self.logit_softcapping,
- causal=self.causal,
- )
+ attn_output = self.flash_attn_triton(query, flatten_k, flatten_v, attn_metadata)
+
return attn_output
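For readers new to the sparse prefill path, here is a shape-only sketch of the query head padding performed in `flash_mla_prefill` above. All sizes are made up, and the real kernel returns `v_head_size` channels per head; only the pad/unpad bookkeeping is shown.

```python
import torch
import torch.nn.functional as F

# flash_mla_sparse_fwd wants a multiple of 64 query heads, so extra heads are
# padded in before the kernel and sliced off again afterwards.
num_tokens, num_q_heads, head_dim = 8, 40, 576
query = torch.randn(num_tokens, num_q_heads, head_dim)

pad_heads = (-num_q_heads) % 64                # 24 extra heads here
padded = F.pad(query, (0, 0, 0, pad_heads))    # pad the head dimension
assert padded.shape == (num_tokens, 64, head_dim)

attn_output = padded[:, :num_q_heads]          # drop the padded heads again
assert attn_output.shape == (num_tokens, num_q_heads, head_dim)
```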
diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py
index deb6c66bfd..947ffe7646 100644
--- a/lmdeploy/pytorch/backends/cuda/graph_runner.py
+++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -82,6 +82,9 @@ def __init__(
input_buffers=dict(),
output_buffers=dict(),
vocab_size=self.model_config.vocab_size,
+ use_mla_fp8_cache=getattr(self.model_config, 'use_mla_fp8_cache', False),
+ use_flash_mla=getattr(self.model_config, 'use_flash_mla', False),
+ mla_index_topk=getattr(self.model_config, 'mla_index_topk', None),
)
self.device = device
self.max_batches = max_batches
diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py
index a1394f28bd..b1c232f0e4 100644
--- a/lmdeploy/pytorch/backends/cuda/moe.py
+++ b/lmdeploy/pytorch/backends/cuda/moe.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import contextlib
from typing import Callable, List, Optional
import torch
@@ -575,6 +576,41 @@ def fusedmoe_forward(self, state, up_weight, up_scale, down_weight, down_scale):
return hidden_states
+@contextlib.contextmanager
+def mock_deep_gemm():
+ from dlblas.kernels.fused_moe_v3 import use_deep_gemm
+ if use_deep_gemm:
+ yield
+ return
+
+ # patch deep_gemm
+ import deep_gemm
+ import dlblas
+
+ from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm
+ func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None)
+ func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None)
+ deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor
+ deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked
+
+ # patch dlblas
+ dlblas.kernels.fused_moe_v3.use_deep_gemm = True
+ dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \
+ patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous
+ yield
+
+ # unpatch dlblas
+ dlblas.kernels.fused_moe_v3.use_deep_gemm = False
+
+ # unpatch deep_gemm
+ if func0_ is not None:
+ deep_gemm.get_col_major_tma_aligned_tensor = func0_
+ deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_
+ else:
+ del deep_gemm.get_col_major_tma_aligned_tensor
+ del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked
+
+
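A possible hardening, not part of this patch: `mock_deep_gemm` patches and unpatches around a bare `yield`, so an exception raised inside the wrapped MoE forward would leave `deep_gemm` and `dlblas` patched. A generic attribute patcher with try/finally, from which the context manager above could be composed, looks like this:

```python
import contextlib


@contextlib.contextmanager
def _patched_attr(obj, name, new_value):
    """Temporarily replace ``obj.<name>``, restoring the original on any exit path."""
    sentinel = object()
    old_value = getattr(obj, name, sentinel)
    setattr(obj, name, new_value)
    try:
        yield
    finally:
        if old_value is sentinel:
            delattr(obj, name)
        else:
            setattr(obj, name, old_value)
```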
class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl):
def __init__(self,
@@ -646,16 +682,35 @@ def do_renormalize(self, topk_weights):
def fusedmoe_build(self, low_latency_mode: bool = False):
from dlblas.layers.moe.ep_moe import build_deepep_moe
- return build_deepep_moe(low_latency_mode,
- self.ep_size,
- self.ep_group,
- self.num_experts,
- self.hidden_dim,
- self.block_size,
- self.top_k,
- self.out_dtype,
- layer_idx=self.layer_idx,
- chunk_size=16 * 1024)
+ deepep_moe = build_deepep_moe(low_latency_mode,
+ self.ep_size,
+ self.ep_group,
+ self.num_experts,
+ self.hidden_dim,
+ self.block_size,
+ self.top_k,
+ self.out_dtype,
+ layer_idx=self.layer_idx,
+ chunk_size=16 * 1024)
+
+ # patch forward
+ _origin_forward = deepep_moe.forward
+ _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward
+
+ def _patched_forward(*args, **kwargs):
+ with mock_deep_gemm():
+ out = _origin_forward(*args, **kwargs)
+ return out
+
+ def _patched_fusedmoe_forward(*args, **kwargs):
+ with mock_deep_gemm():
+ out = _origin_fusedmoe_forward(*args, **kwargs)
+ return out
+
+ deepep_moe.forward = _patched_forward
+ deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward
+
+ return deepep_moe
class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
diff --git a/lmdeploy/pytorch/backends/cuda/nsa.py b/lmdeploy/pytorch/backends/cuda/nsa.py
new file mode 100644
index 0000000000..ed0c471a39
--- /dev/null
+++ b/lmdeploy/pytorch/backends/cuda/nsa.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor
+
+from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index
+from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8
+
+from ..nsa import BaseNSAIndexFP8, BaseNSAIndexFP8Builder, NSAIndexMeta
+
+
+class TritonNSAIndexFP8(BaseNSAIndexFP8):
+
+ def __init__(self, topk: int, softmax_scale: float, block_size: int, fill: int) -> None:
+ super().__init__()
+ self.topk = topk
+ self.softmax_scale = softmax_scale
+ self.block_size = block_size
+ self.fill = fill
+
+ def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor,
+ meta: NSAIndexMeta) -> Tensor:
+
+ assert q.dim() == 3
+ assert k.dim() == 2
+ cu_seqlen_q = meta.cu_seqlen_q
+ q_seqlens = meta.q_seqlens
+ k_seqlens = meta.k_seqlens
+ block_offset = meta.block_offset
+ max_q_seqlen = meta.max_q_seqlen
+ max_kv_seqlen = meta.max_kv_seqlen
+
+ q_shape = q.shape
+ q = q.reshape(-1, q_shape[-1])
+ q, q_s = quant_fp8(q, self.block_size, dtype=k_cache.dtype, trans_scale=True)
+ q = q.reshape(*q_shape)
+ q_s = q_s.reshape(weights.shape)
+ q_s = q_s * self.softmax_scale * weights
+
+ fill_kv_cache_blocked_fp8(k[:, None],
+ None,
+ k_cache[..., None, :],
+ None,
+ k_s_cache[..., None, :],
+ None,
+ cu_seqlen_q=cu_seqlen_q,
+ kv_seqlens=k_seqlens,
+ max_q_seqlen=max_q_seqlen,
+ block_offsets=block_offset,
+ group_size=self.block_size)
+
+ scores = fp8_index(q,
+ q_s,
+ k_cache,
+ k_s_cache[..., 0],
+ cu_seqlen_q,
+ k_seqlens,
+ block_offset,
+ max_q_seqlen=max_q_seqlen,
+ max_k_seqlen=max_kv_seqlen)
+ return bitonic_topk(scores, q_seqlens, k_seqlens, self.topk, fill=self.fill, descending=True)
+
+
+class TritonNSAIndexFP8Builder(BaseNSAIndexFP8Builder):
+
+ @staticmethod
+ def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8:
+ return TritonNSAIndexFP8(topk, softmax_scale=softmax_scale, block_size=block_size, fill=fill)
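A rough, CUDA-only usage sketch for the builder. The metadata fields mirror the `NSAIndexMeta` dataclass added in `lmdeploy/pytorch/backends/nsa.py`, while all numeric values and the block table below are placeholders rather than real call-site shapes.

```python
import torch
import torch.nn.functional as F

from lmdeploy.pytorch.backends.cuda.nsa import TritonNSAIndexFP8Builder
from lmdeploy.pytorch.backends.nsa import NSAIndexMeta

indexer = TritonNSAIndexFP8Builder.build(topk=2048, softmax_scale=0.1, block_size=128, fill=-1)

q_seqlens = torch.tensor([4, 6], device='cuda')
k_seqlens = torch.tensor([16, 32], device='cuda')
meta = NSAIndexMeta(
    cu_seqlen_q=F.pad(q_seqlens.cumsum(0), (1, 0)).int(),
    q_seqlens=q_seqlens,
    k_seqlens=k_seqlens,
    block_offset=torch.tensor([[0], [1]], device='cuda', dtype=torch.int32),  # placeholder block table
    max_q_seqlen=6,
    max_kv_seqlen=32,
)
# indexer.forward(q, k, weights, k_cache, k_s_cache, meta) returns the per-token
# top-k key indices that are later passed as `nsa_indices` to the sparse FlashMLA path.
```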
diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py
index d6b77de59e..94b6d507df 100644
--- a/lmdeploy/pytorch/backends/cuda/op_backend.py
+++ b/lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -12,13 +12,6 @@
logger = get_logger('lmdeploy')
-def _get_meta_flashmla(kv_seqlens, num_attention_heads):
- """Get meta for flashmla."""
- import flash_mla
- tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(kv_seqlens.to(torch.int32), num_attention_heads, 1)
- return tile_scheduler_metadata, num_splits
-
-
class CudaOpsBackend(DefaultOpsBackend):
"""Cuda layer backend."""
@@ -72,6 +65,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.LinearBlockedF8:
from .blockedf8_modules import TritonLinearBlockedF8Builder
return TritonLinearBlockedF8Builder
+ elif layer_type == OpType.NSAIndexFP8:
+ from .nsa import TritonNSAIndexFP8Builder
+ return TritonNSAIndexFP8Builder
else:
logger.debug(f'Op {layer_type} fallback to default implementation.')
return super().get_layer_impl_builder(layer_type)
@@ -111,10 +107,19 @@ def get_v_block_shape(
)
@classmethod
- def update_meta_flashmla(cls, attn_metadata, num_attention_heads):
+ def update_meta_flashmla(cls, attn_metadata, model_config: ModelConfig):
"""Update meta for flashmla."""
- tile_scheduler_metadata, num_splits = _get_meta_flashmla(attn_metadata.kv_seqlens.to(torch.int32),
- num_attention_heads)
+ import flash_mla
+ num_attention_heads = model_config.num_attention_heads
+ is_fp8_kvcache = model_config.use_mla_fp8_cache
+ index_topk = model_config.mla_index_topk
+ num_heads_q = None if index_topk is None else num_attention_heads
+ tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32),
+ num_attention_heads,
+ num_heads_k=1,
+ num_heads_q=num_heads_q,
+ is_fp8_kvcache=is_fp8_kvcache,
+ topk=index_topk)
attn_metadata.tile_scheduler_metadata = tile_scheduler_metadata
attn_metadata.num_splits = num_splits
@@ -149,7 +154,8 @@ def update_step_context(cls, step_context):
)
if getattr(step_context.model_config, 'use_flash_mla', False) is True:
if step_context.is_decoding is True:
- cls.update_meta_flashmla(attn_metadata, step_context.model_config.num_attention_heads)
+ model_config = step_context.model_config
+ cls.update_meta_flashmla(attn_metadata, model_config)
cross_seqlens = step_context.cross_seqlens
cross_kv_seqlens = step_context.cross_kv_seqlens
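The new `NSAIndexFP8` op type is resolved through the same backend dispatch as the other layer types; a small sketch, assuming `get_layer_impl_builder` keeps its existing classmethod form:

```python
from lmdeploy.pytorch.backends.base import OpType
from lmdeploy.pytorch.backends.cuda.op_backend import CudaOpsBackend

# Resolves to TritonNSAIndexFP8Builder from lmdeploy/pytorch/backends/cuda/nsa.py
builder = CudaOpsBackend.get_layer_impl_builder(OpType.NSAIndexFP8)
print(builder.__name__)
```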
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 228e4de1c4..db92aa39f7 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -66,6 +66,7 @@ def forward(
k_scales_zeros: Tensor = None,
v_scales_zeros: Tensor = None,
learnable_sink: Tensor = None,
+ nsa_indices: Tensor = None,
inplace: bool = True,
) -> Tensor:
"""forward."""
diff --git a/lmdeploy/pytorch/backends/nsa.py b/lmdeploy/pytorch/backends/nsa.py
new file mode 100644
index 0000000000..c20c80868c
--- /dev/null
+++ b/lmdeploy/pytorch/backends/nsa.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from torch import Tensor
+
+
+@dataclass
+class NSAIndexMeta:
+ """Meta info of NSAIndex layer."""
+ cu_seqlen_q: Tensor
+ q_seqlens: Tensor
+ k_seqlens: Tensor
+ block_offset: Tensor
+ max_q_seqlen: int = None
+ max_kv_seqlen: int = None
+
+
+class BaseNSAIndexFP8(ABC):
+
+ @abstractmethod
+ def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor,
+ meta: NSAIndexMeta) -> Tensor:
+ """forward."""
+ raise NotImplementedError('Not implemented.')
+
+
+class BaseNSAIndexFP8Builder:
+
+ @staticmethod
+ @abstractmethod
+ def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8:
+ """Build layer implementation."""
+ raise NotImplementedError('Not implemented.')
diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py
index d1d096dbfe..0e5884b69b 100644
--- a/lmdeploy/pytorch/check_env/model.py
+++ b/lmdeploy/pytorch/check_env/model.py
@@ -19,8 +19,8 @@ def check_config(self, trans_version):
model_path = self.model_path
trust_remote_code = self.trust_remote_code
try:
- from transformers import AutoConfig
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+ from lmdeploy.pytorch.transformers import config_from_pretrained
+ config = config_from_pretrained(model_path, trust_remote_code=trust_remote_code)
except Exception as e:
message = (f'Load model config with transformers=={trans_version}'
' failed. '
@@ -57,7 +57,13 @@ def check_dtype(self, config):
if not is_bf16_supported(device_type):
logger.warning('Device does not support bfloat16.')
except Exception as e:
- message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.')
+ message = (f'Checking failed with error {e}. Please send issue to LMDeploy with error logs.')
+ self.log_and_exit(e, 'Model', message=message)
+
+ try:
+ model_config.check_env_func(device_type)
+ except Exception as e:
+ message = (f'Checking failed with error {e}.')
self.log_and_exit(e, 'Model', message=message)
def check(self):
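For context, a per-model check plugged in through the new `check_env_func` hook follows the shape below. The package name is purely illustrative; the real DeepSeek-V3.2 check lives in `lmdeploy/pytorch/configurations/deepseek_v2.py` later in this patch.

```python
def _check_env_example(device: str):
    """Raise if the device/runtime combination cannot run this model."""
    if device != 'cuda':
        return
    try:
        import some_required_kernel_package  # noqa: F401  (illustrative name)
    except ImportError as e:
        raise ImportError('this model requires some_required_kernel_package on CUDA') from e
```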
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index ac3459e045..403d07fb37 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
-from dataclasses import dataclass
-from typing import Any, Dict, List, Literal
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import torch
@@ -183,6 +183,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]):
_override_hf_config(hf_config, k, v)
+def _default_check_env(device: str):
+ pass
+
+
@dataclass
class ModelConfig:
"""Config of model."""
@@ -203,11 +207,24 @@ class ModelConfig:
llm_config: Any = None
cogvlm_style: bool = False
custom_module_map: Dict[str, setattr] = None
+
+ # flash mla
use_flash_mla: bool = False
+ use_mla_fp8_cache: bool = False
+ mla_index_topk: Optional[int] = None
+
+ # dllm
model_paradigm: str = 'ar'
dllm_mask_token: int = 0
dllm_block_length: int = None
+    # Added for the DeepSeek-V3.2 NSA index:
+    # these caches are allocated after the kv cache
+ cache_shapes: List[Tuple[List[int], torch.dtype]] = field(default_factory=list)
+
+ # check env for model-device combination
+ check_env_func: Callable = _default_check_env
+
def get_head_size(self):
"""Get head size."""
return self.head_dim
@@ -232,9 +249,9 @@ def from_pretrained(cls,
"""
from transformers import AutoConfig
+ from lmdeploy.pytorch.transformers import config_from_pretrained
from lmdeploy.utils import get_logger
-
- hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
+ hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if getattr(hf_config, 'model_type', None) in ['phi3']:
# phi3 + trust_remote_code leads to error when tp.
hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
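To illustrate the new `cache_shapes` field: each `(shape, dtype)` pair describes an extra per-token buffer that the cache engine allocates alongside the KV cache, one `(block_size, *shape)` block at a time (see `get_custom_cache_descs` later in this patch). The 128 below stands in for `hf_config.index_head_dim`, whose actual value depends on the checkpoint.

```python
import torch

cache_shapes = [([128], torch.float8_e4m3fn),  # index_k: one fp8 vector per token (placeholder dim)
                ([1], torch.float32)]          # index_k scale: one fp32 per token
block_size = 64
per_block_shapes = [(block_size, *shape) for shape, _ in cache_shapes]
print(per_block_shapes)  # [(64, 128), (64, 1)]
```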
diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py
index f83abe38f2..f0eeb344ec 100644
--- a/lmdeploy/pytorch/configurations/deepseek_v2.py
+++ b/lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -1,16 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
from lmdeploy.pytorch.config import ModelConfig
from .builder import AutoModelConfigBuilder
from .utils import flash_mla_available
+def _check_env_v32(device: str = 'cuda'):
+ """Environment check."""
+ if device != 'cuda':
+ return
+
+ # check cuda
+ try:
+ import fast_hadamard_transform # noqa: F401
+ except ImportError:
+        raise ImportError('DeepSeek-V3.2 requires the fast_hadamard_transform package.')
+
+ try:
+ import flash_mla # noqa: F401
+ except ImportError:
+        raise ImportError('DeepSeek-V3.2 requires the flash_mla package.')
+
+ if not hasattr(flash_mla, 'flash_mla_sparse_fwd'):
+        raise RuntimeError('DeepSeek-V3.2 requires a flash_mla build with flash_mla_sparse_fwd support.')
+
+
class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):
@classmethod
def condition(cls, hf_config):
"""config."""
- return hf_config.model_type in ['deepseek_v3', 'deepseek_v2', 'kimi_k2']
+ return hf_config.model_type in ['deepseek_v3', 'deepseek_v2', 'deepseek_v32', 'kimi_k2']
@classmethod
def build(cls, hf_config, model_path: str = None, **kwargs):
@@ -26,14 +48,24 @@ def build(cls, hf_config, model_path: str = None, **kwargs):
num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads)
hf_config.use_flash_mla = flash_mla_available()
- return ModelConfig(hidden_size=hf_config.hidden_size,
- num_layers=hf_config.num_hidden_layers,
- num_attention_heads=num_attention_heads,
- num_key_value_heads=num_key_value_heads,
- bos_token_id=hf_config.bos_token_id,
- eos_token_id=hf_config.eos_token_id,
- head_dim=head_dim,
- k_head_dim=k_head_dim,
- v_head_dim=v_head_dim,
- vocab_size=hf_config.vocab_size,
- use_flash_mla=hf_config.use_flash_mla)
+ config = ModelConfig(hidden_size=hf_config.hidden_size,
+ num_layers=hf_config.num_hidden_layers,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ bos_token_id=hf_config.bos_token_id,
+ eos_token_id=hf_config.eos_token_id,
+ head_dim=head_dim,
+ k_head_dim=k_head_dim,
+ v_head_dim=v_head_dim,
+ vocab_size=hf_config.vocab_size,
+ use_flash_mla=hf_config.use_flash_mla)
+
+ if hf_config.model_type == 'deepseek_v32':
+ assert hf_config.use_flash_mla, 'DeepSeek-V3.2 requires flash_mla to be available.'
+ index_k_shape = ([hf_config.index_head_dim], torch.float8_e4m3fn)
+ index_k_scale_shape = ([1], torch.float32)
+ config.cache_shapes = [index_k_shape, index_k_scale_shape]
+ config.use_mla_fp8_cache = True
+ config.mla_index_topk = hf_config.index_topk
+ config.check_env_func = _check_env_v32
+ return config
diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py
index d8ec198349..bbe18cd7af 100644
--- a/lmdeploy/pytorch/engine/cache_engine.py
+++ b/lmdeploy/pytorch/engine/cache_engine.py
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
import json
-from typing import Dict, List, Literal, Optional, Tuple
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Literal, Optional, Sequence, Tuple
import torch
@@ -20,6 +22,35 @@
logger = get_logger('lmdeploy')
+def round_up(x: int, alignment: int) -> int:
+ """Round up x to the nearest multiple of alignment."""
+ return ((x + alignment - 1) // alignment) * alignment
+
+
+@dataclass
+class CacheDesc:
+ """Cache description."""
+ shape: List[int]
+ dtype: torch.dtype
+ alignment: int = 256
+
+ def __post_init__(self):
+ self.numel = math.prod(self.shape)
+ self.size = self.numel * self.dtype.itemsize
+ self.aligned_size = round_up(self.size, self.alignment)
+
+
+def _get_kv_cache_dtype(model_config: ModelConfig):
+ kv_cache_dtype = model_config.dtype
+ if model_config.use_mla_fp8_cache:
+ kv_cache_dtype = torch.float8_e4m3fn
+ return kv_cache_dtype
+
+
+# 512*1 + 4*4 + 64*2 = 656
+MLA_FP8_HEAD_DIM = 656
+
+
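The 656-byte figure can be recomputed from the slicing done in the CUDA attention backend (`k_cache[..., :512]`, `k_cache[..., 512:512 + 16]`, `k_cache[..., 512 + 16:]`): a 512-dim fp8 "nope" part, 512 / 128 = 4 fp32 group scales, and a 64-dim 16-bit rope part.

```python
import torch

nope = 512 * torch.float8_e4m3fn.itemsize  # fp8 latent kv: 512 bytes
scales = 4 * torch.float32.itemsize        # one scale per 128-dim group: 16 bytes
rope = 64 * 2                              # bf16/fp16 rope keys: 128 bytes
assert nope + scales + rope == 656
```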
class CacheEngine:
"""Host and Device memory maintainer.
@@ -50,7 +81,11 @@ def __init__(
self.block_size = cache_config.block_size
self.num_layers = model_config.num_layers
- self.kv_cache_dtype = model_config.dtype
+ self.kv_cache_dtype = _get_kv_cache_dtype(self.model_config)
+
+ if self.model_config.use_mla_fp8_cache:
+ cache_config.quant_policy = 0
+
if cache_config.quant_policy > 0:
if self.cache_config.device_type in ['cuda']:
self.kv_cache_dtype = torch.uint8
@@ -100,16 +135,21 @@ def _get_key_block_shape_impl(cls,
block_size: int,
head_size: int,
world_size: int = 1,
- quant_policy: Literal[0, 4, 8] = 0,
- local: bool = True):
+ quant_policy: Literal[0, 4, 8] = 0):
"""Get single block shape."""
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
- if local:
- assert num_heads % world_size == 0, \
- f'num_heads: {num_heads}, world_size: {world_size}'
- num_heads = num_heads // world_size
+
+ # split heads by tp
+ assert num_heads % world_size == 0, \
+ f'num_heads: {num_heads}, world_size: {world_size}'
+ num_heads = num_heads // world_size
+
+ # patch for flash mla
+ if model_config.use_mla_fp8_cache:
+ return (block_size, num_heads, MLA_FP8_HEAD_DIM)
+
if quant_policy == 4: # pack head_dim to uint8
assert head_size % 2 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
@@ -122,16 +162,22 @@ def _get_value_block_shape_impl(cls,
block_size: int,
head_size: int,
world_size: int = 1,
- quant_policy: Literal[0, 4, 8] = 0,
- local: bool = True):
+ quant_policy: Literal[0, 4, 8] = 0):
"""Get single block shape."""
attn_backend = get_backend()
dtype = model_config.dtype
num_heads = model_config.num_key_value_heads
- if local:
- assert num_heads % world_size == 0, \
- f'num_heads: {num_heads}, world_size: {world_size}'
- num_heads = num_heads // world_size
+
+ # split heads by tp
+ assert num_heads % world_size == 0, \
+ f'num_heads: {num_heads}, world_size: {world_size}'
+ num_heads = num_heads // world_size
+
+ # patch for flash mla
+ if model_config.use_mla_fp8_cache:
+            # flash mla shares storage between key and value, so the value block is empty
+ return (block_size, num_heads, 0)
+
if quant_policy == 4: # pack head_dim to uint8
assert head_size % 2 == 0, \
f'head_size: {head_size}, quant_policy: {quant_policy}'
@@ -139,86 +185,155 @@ def _get_value_block_shape_impl(cls,
return attn_backend.get_v_block_shape(block_size, num_heads, head_size, dtype)
- def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]:
- """Get shape of key block."""
- head_size = self.model_config.k_head_dim
+ @classmethod
+ def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc:
+ """Get key cache description."""
+ head_size = model_config.k_head_dim
if head_size is None:
- head_size = self.model_config.head_dim
- return self._get_key_block_shape_impl(
- self.model_config,
- block_size=self.block_size,
+ head_size = model_config.head_dim
+ shape = cls._get_key_block_shape_impl(
+ model_config,
+ block_size=cache_config.block_size,
head_size=head_size,
- world_size=self.world_size,
- quant_policy=self.cache_config.quant_policy,
- local=local,
+ world_size=world_size,
+ quant_policy=cache_config.quant_policy,
)
+ shape = list(shape)
+ dtype = _get_kv_cache_dtype(model_config)
+ if cache_config.quant_policy in (4, 8):
+ dtype = torch.uint8
+ return CacheDesc(shape=shape, dtype=dtype)
- def get_value_block_shape(self, local: bool = False) -> Tuple[int, int, int]:
- """Get shape of value block."""
- head_size = self.model_config.v_head_dim
+ @classmethod
+ def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc:
+ """Get value cache description."""
+ head_size = model_config.v_head_dim
if head_size is None:
- head_size = self.model_config.head_dim
- return self._get_value_block_shape_impl(
- self.model_config,
- block_size=self.block_size,
+ head_size = model_config.head_dim
+ shape = cls._get_value_block_shape_impl(
+ model_config,
+ block_size=cache_config.block_size,
head_size=head_size,
- world_size=self.world_size,
- quant_policy=self.cache_config.quant_policy,
- local=local,
+ world_size=world_size,
+ quant_policy=cache_config.quant_policy,
)
+ shape = list(shape)
+ dtype = _get_kv_cache_dtype(model_config)
+ if cache_config.quant_policy in (4, 8):
+ dtype = torch.uint8
+ return CacheDesc(shape=shape, dtype=dtype)
- def _allocate_cache(self, num_blocks: int, device: torch.device):
- """Allocate cache implement."""
- key_block_shape = self.get_key_block_shape(local=True)
- value_block_shape = self.get_value_block_shape(local=True)
+ @classmethod
+ def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, model_config: ModelConfig,
+ cache_config: CacheConfig):
+ """Get quant cache descs."""
+ if cache_config.quant_policy == 0:
+ return []
- num_layers = self.num_layers
- kv_cache_dtype = self.kv_cache_dtype
+ dtype = model_config.dtype
+ key_scale_zero_shape = k_cache_desc.shape[:-1] + [2]
+ val_scale_zero_shape = v_cache_desc.shape[:-1] + [2]
+ key_scale_zero_desc = CacheDesc(shape=key_scale_zero_shape, dtype=dtype)
+ val_scale_zero_desc = CacheDesc(shape=val_scale_zero_shape, dtype=dtype)
+ return [key_scale_zero_desc, val_scale_zero_desc]
- key_cache = torch.empty(
- size=(num_layers, num_blocks, *key_block_shape),
- dtype=kv_cache_dtype,
- device=device,
- )
- value_cache = torch.empty(
- size=(num_layers, num_blocks, *value_block_shape),
- dtype=kv_cache_dtype,
- device=device,
- )
+ @classmethod
+ def get_custom_cache_descs(cls, model_config: ModelConfig, cache_config: CacheConfig) -> List[CacheDesc]:
+ """Get custom cache descs."""
+ if len(model_config.cache_shapes) == 0:
+ return []
- output = (key_cache, value_cache)
+ block_size = cache_config.block_size
- if self.cache_config.quant_policy in (4, 8):
- dtype = self.model_config.dtype
- key_sz_cache = torch.empty(
- size=(num_layers, num_blocks, *key_block_shape[:-1], 2),
- dtype=dtype,
- device=device,
- )
- val_sz_cache = torch.empty(
- size=(num_layers, num_blocks, *value_block_shape[:-1], 2),
- dtype=dtype,
- device=device,
- )
- output = output + (key_sz_cache, val_sz_cache)
+ descs = []
+ for shape, dtype in model_config.cache_shapes:
+ custom_shape = (block_size, *shape)
+ desc = CacheDesc(shape=custom_shape, dtype=dtype)
+ descs.append(desc)
+ return descs
- return output
+ @classmethod
+ def allocate_caches(cls, num_blocks: int, model_config: ModelConfig, cache_config: CacheConfig, world_size: int,
+ device: str):
+ """Allocate caches."""
+
+ num_layers = model_config.num_layers
+
+ # get all descs
+ k_cache_desc = cls.get_k_cache_desc(model_config, cache_config, world_size)
+ v_cache_desc = cls.get_v_cache_desc(model_config, cache_config, world_size)
+ quant_cache_descs = cls.get_quant_cache_descs(k_cache_desc, v_cache_desc, model_config, cache_config)
+ custom_cache_descs = cls.get_custom_cache_descs(model_config, cache_config)
+ cache_descs = [k_cache_desc, v_cache_desc] + quant_cache_descs + custom_cache_descs
+
+ # get mempool size
+ mem_pool_size = 0
+ for desc in cache_descs:
+ mem_pool_size += desc.aligned_size
+
+ # create pool
+ mem_pool = torch.zeros((num_layers, num_blocks, mem_pool_size), dtype=torch.uint8, device=device)
+
+ # slice caches
+ caches = []
+ remain_pool = mem_pool
+ for desc in cache_descs:
+ cache = remain_pool[:, :, :desc.size].view(desc.dtype).view((num_layers, num_blocks, *desc.shape))
+ remain_pool = remain_pool[:, :, desc.aligned_size:]
+ caches.append(cache)
+ return mem_pool, caches
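A standalone sketch of the pooled layout built by `allocate_caches`: every cache description gets a 256-byte-aligned slice of a single uint8 pool per (layer, block) and is then viewed back to its own dtype and shape. The shapes below are illustrative only.

```python
import math

import torch


def aligned_size(shape, dtype, alignment=256):
    size = math.prod(shape) * dtype.itemsize
    return size, ((size + alignment - 1) // alignment) * alignment


descs = [((64, 1, 656), torch.uint8),       # fp8 MLA KV block
         ((64, 128), torch.float8_e4m3fn),  # custom index_k block
         ((64, 1), torch.float32)]          # custom index_k scale block
sizes = [aligned_size(shape, dtype) for shape, dtype in descs]

num_layers, num_blocks = 2, 4
pool = torch.zeros((num_layers, num_blocks, sum(a for _, a in sizes)), dtype=torch.uint8)

caches, remain = [], pool
for (shape, dtype), (size, aligned) in zip(descs, sizes):
    caches.append(remain[:, :, :size].view(dtype).view(num_layers, num_blocks, *shape))
    remain = remain[:, :, aligned:]
print([tuple(c.shape) for c in caches])
# [(2, 4, 64, 1, 656), (2, 4, 64, 128), (2, 4, 64, 1)]
```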
def allocate_gpu_cache(self):
"""Allocate caches on GPU."""
- caches = self._allocate_cache(self.num_gpu_blocks, 'cuda')
- self.full_gpu_cache = caches
+ mem_pool, caches = self.allocate_caches(
+ num_blocks=self.num_gpu_blocks,
+ model_config=self.model_config,
+ cache_config=self.cache_config,
+ world_size=self.world_size,
+ device='cuda',
+ )
+ self.full_gpu_cache = mem_pool
self.local_gpu_cache = list(zip(*caches))
return self.local_gpu_cache
def allocate_cpu_cache(self):
"""Allocate caches on Host."""
- caches = self._allocate_cache(self.num_cpu_blocks, 'cpu')
-
- self.full_cpu_cache = caches
+ mem_pool, caches = self.allocate_caches(
+ num_blocks=self.num_cpu_blocks,
+ model_config=self.model_config,
+ cache_config=self.cache_config,
+ world_size=self.world_size,
+ device='cpu',
+ )
+ self.full_cpu_cache = mem_pool
self.local_cpu_cache = list(zip(*caches))
return self.local_cpu_cache
+ @staticmethod
+ def get_custom_cache_shape_impl(num_layers: int, num_blocks: int, block_size: int, shape: List[int]):
+ """Get single block shape."""
+ return (num_layers, num_blocks, block_size, *shape)
+
+ @staticmethod
+ def _allocate_single_custom_cache(shape: Sequence[int], dtype: torch.dtype, device: str):
+ """Allocate custom cache."""
+ return torch.empty(shape, dtype=dtype, device=device)
+
+ def allocate_custom_cache(self, device: str):
+ """Allocate custom caches on GPU."""
+ num_layers = self.model_config.num_layers
+ custom_caches = []
+ for shape, dtype in self.model_config.cache_shapes:
+ custom_shape = self.get_custom_cache_shape_impl(
+ num_layers=num_layers,
+ num_blocks=self.num_gpu_blocks,
+ block_size=self.block_size,
+ shape=shape,
+ )
+ custom_cache = self._allocate_single_custom_cache(shape=custom_shape, dtype=dtype, device=device)
+ custom_caches.append(custom_cache)
+ return custom_caches
+
@torch.inference_mode()
def _swap(self, src: List[torch.Tensor], dst: List[torch.Tensor], src_to_dst: Dict[int, int]):
"""Move caches from src memory to dst memory.
@@ -248,7 +363,7 @@ def swap_in(self, src_to_dst: Dict[int, int]) -> None:
Args:
src_to_dst (Dict[int, int]): Map between src and dst.
"""
- self._swap(self.full_cpu_cache, self.full_gpu_cache, src_to_dst)
+ self._swap([self.full_cpu_cache], [self.full_gpu_cache], src_to_dst)
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
"""Move cache from Device to Host.
@@ -256,14 +371,10 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None:
Args:
src_to_dst (Dict[int, int]): Map between src and dst.
"""
- self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst)
+ self._swap([self.full_gpu_cache], [self.full_cpu_cache], src_to_dst)
@classmethod
- def get_cache_block_size(cls,
- block_size: int,
- model_config: ModelConfig,
- world_size: int = 1,
- quant_policy: int = 0) -> int:
+ def get_cache_block_size(cls, cache_config: CacheConfig, model_config: ModelConfig, world_size: int = 1) -> int:
"""Get the required cache size of the model.
Args:
@@ -273,49 +384,15 @@ def get_cache_block_size(cls,
Return:
int: Required memory size in bytes.
"""
- num_layers = model_config.num_layers
- key_head_size = model_config.k_head_dim
- value_head_size = model_config.v_head_dim
- if key_head_size is None:
- key_head_size = model_config.head_dim
- if value_head_size is None:
- value_head_size = model_config.head_dim
- key_shape = cls._get_key_block_shape_impl(
- model_config,
- block_size=block_size,
- head_size=key_head_size,
- world_size=world_size,
- local=True,
- quant_policy=quant_policy,
- )
- value_shape = cls._get_value_block_shape_impl(
- model_config,
- block_size=block_size,
- head_size=value_head_size,
+ mem_pool, _ = cls.allocate_caches(
+ num_blocks=1,
+ model_config=model_config,
+ cache_config=cache_config,
world_size=world_size,
- quant_policy=quant_policy,
- local=True,
+ device='meta',
)
- if quant_policy == 0:
- dtype = model_config.dtype
- key_block = torch.empty(key_shape, dtype=dtype, device='meta')
- value_block = torch.empty(value_shape, dtype=dtype, device='meta')
- mem_key_block = key_block.numel() * key_block.element_size()
- mem_value_block = value_block.numel() * value_block.element_size()
- elif quant_policy in (4, 8):
- key_block = torch.empty(key_shape, dtype=torch.uint8, device='meta')
- value_block = torch.empty(value_shape, dtype=torch.uint8, device='meta')
- key_scale_zero_block = torch.empty((*key_shape[:-1], 2), dtype=model_config.dtype, device='meta')
- value_scale_zero_block = torch.empty((*value_shape[:-1], 2), dtype=model_config.dtype, device='meta')
- mem_key_block = key_block.numel() * key_block.element_size() + key_scale_zero_block.numel(
- ) * key_scale_zero_block.element_size()
- mem_value_block = value_block.numel() * value_block.element_size() + value_scale_zero_block.numel(
- ) * value_scale_zero_block.element_size()
- else:
- raise ValueError(f'unsupported quant_policy {quant_policy}')
-
- total = num_layers * (mem_key_block + mem_value_block)
- return total
+
+ return mem_pool.numel() * mem_pool.element_size()
""" Metheds for PD Disaggregation Begin. """
@@ -324,7 +401,7 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistSe
self.migration_backend_impl = MIGRATION_BACKENDS.module_dict[self.cache_config.migration_backend.name]()
migration_init_request.rank = self.rank
self.migration_backend_impl.p2p_initialize(migration_init_request)
- for i, t in enumerate(self.full_gpu_cache):
+ for i, t in enumerate([self.full_gpu_cache]):
if t.numel() == 0:
continue
register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.protocol,
@@ -370,7 +447,7 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote
remote_layer_stride = self.migration_backend_impl.links[
remote_engine_id].remote_engine_config.num_gpu_blocks * assignment_len
- for i, t in enumerate(self.full_gpu_cache):
+ for i, t in enumerate([self.full_gpu_cache]):
if t.numel() == 0:
continue
assignment_batch.extend(
diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py
index 9e50843a80..0c748c29f9 100644
--- a/lmdeploy/pytorch/engine/executor/base.py
+++ b/lmdeploy/pytorch/engine/executor/base.py
@@ -164,8 +164,7 @@ def update_configs(self):
vocal_size = self.model_config.vocab_size
tp = self.dist_config.attn_config.tp
- cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp,
- cache_config.quant_policy)
+ cache_block_size = CacheEngine.get_cache_block_size(cache_config, model_config, tp)
runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size, vocal_size)
if cache_config.max_prefill_token_num != max_prefill_token_num:
if max_prefill_token_num <= 0:
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 2ba2850c75..a48153f08a 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -232,6 +232,7 @@ def model_forward(
context = ctx_mgr.build_context(
inputs=inputs,
model_config=cache_engine.model_config,
+ cache_config=cache_engine.cache_config,
kv_caches=cache_engine.gpu_cache,
kv_quant_policy=cache_engine.cache_config.quant_policy,
)
diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py
index f4ae57714b..006ddada66 100644
--- a/lmdeploy/pytorch/kernels/cuda/__init__.py
+++ b/lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -3,7 +3,6 @@
from .alibi_pagedattention import alibi_paged_attention_fwd
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
-from .flash_mla import flash_mla_fwd
from .flashattention import flash_attention_fwd
from .flatten_kv_cache import flatten_kv_cache
from .fused_moe import fused_moe
diff --git a/lmdeploy/pytorch/kernels/cuda/bitonic_topk.py b/lmdeploy/pytorch/kernels/cuda/bitonic_topk.py
new file mode 100644
index 0000000000..3a30d6e44f
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/bitonic_topk.py
@@ -0,0 +1,215 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import triton
+import triton.language as tl
+from triton.language import core
+from triton.language.standard import _log2
+
+
+@triton.jit
+def _compare_and_swap(x, ids, flip, i: tl.constexpr, n_dims: tl.constexpr):
+ n_outer: tl.constexpr = x.numel >> n_dims
+ shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
+ y = tl.reshape(x, shape)
+ # slice left/right with 'stride' 2**(n_dims - i - 1)
+ mask = tl.arange(0, 2)[None, :, None]
+ left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape)
+ right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape)
+ left = tl.reshape(left, x.shape)
+ right = tl.reshape(right, x.shape)
+
+ # idx
+ y_idx = tl.reshape(ids, shape)
+ left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
+ right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
+ left_idx = tl.reshape(left_idx, x.shape)
+ right_idx = tl.reshape(right_idx, x.shape)
+
+ # actual compare-and-swap
+ idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+ ileft = left.to(idtype, bitcast=True)
+ iright = right.to(idtype, bitcast=True)
+ ix = x.to(idtype, bitcast=True)
+
+ cond = (left > right) ^ flip
+ cond = cond.to(tl.int1)
+
+ ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
+
+ new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids))
+
+ return ret.to(x.dtype, bitcast=True), new_ids
+
+
+@triton.jit
+def _bitonic_merge(x, ids, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
+    """Bitonic merge.
+    order 0 == ascending, order 1 == descending, order 2 == alternating."""
+ n_outer: tl.constexpr = x.numel >> n_dims
+ tl.static_assert(stage <= n_dims)
+ # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+ # descending order.
+ # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+ # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+ # a stride of 2) at this stage
+ if order == 2:
+ shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
+ flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
+ else:
+ flip = order
+ # perform `stage` rounds of `compare-and-swap`
+ for i in tl.static_range(stage):
+ x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
+ return x, ids
+
+
+@triton.jit
+def argsort(x, ids, dim: tl.constexpr = None, descending: tl.constexpr = core.CONSTEXPR_0):
+ # handle default dimension or check that it is the most minor dim
+ _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
+ tl.static_assert(_dim == len(x.shape) - 1, 'only minor dimension is currently supported')
+ # iteratively run bitonic merge-sort steps
+ n_dims: tl.constexpr = _log2(x.shape[_dim])
+
+ for i in tl.static_range(1, n_dims + 1):
+ x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
+ return x, ids
+
+
+@triton.jit
+def _bitonic_topk_kernel0(score_ptr,
+ seqlen_ptr,
+ out_ptr,
+ ids_ptr,
+ stride_m: tl.constexpr,
+ K: tl.constexpr,
+ fill: tl.constexpr,
+ descending: tl.constexpr = core.CONSTEXPR_0):
+ """kernel0."""
+ batch_id = tl.program_id(0).to(tl.int64)
+ block_id = tl.program_id(1).to(tl.int64)
+
+ seqlen = tl.load(seqlen_ptr + batch_id)
+
+ if block_id * K >= seqlen:
+ return
+
+ offs_k = tl.arange(0, K)
+ origin_ids = block_id * K + offs_k
+ mask = (origin_ids < seqlen)
+ score_ptrs = score_ptr + batch_id * stride_m + origin_ids
+ scores = tl.load(score_ptrs, mask=mask, other=-1e6)
+ ids = tl.where(mask, origin_ids, fill)
+ ids = origin_ids
+
+ scores, ids = argsort(scores, ids, 0, descending)
+
+ tl.store(out_ptr + batch_id * stride_m + origin_ids, scores, mask=mask)
+ tl.store(ids_ptr + batch_id * stride_m + origin_ids, ids, mask=mask)
+
+
+@triton.jit
+def _concate(a, b):
+ """concate."""
+ c = tl.join(a, b) # [k, 2]
+ c = c.trans() # [2, k]
+    # `tl.ravel` has bugs when triton<=3.2.0, so reshape manually
+ c = tl.reshape(c, (a.numel + b.numel, ))
+ return c
+
+
+@triton.jit
+def _split(a, k):
+ """split."""
+ a = a.reshape(2, k)
+ a = a.trans()
+ return tl.split(a)
+
+
+@triton.jit
+def _bitonic_topk_kernel1(score_ptr,
+ ids_ptr,
+ seqlen_ptr,
+ out_ptr,
+ stride_m: tl.constexpr,
+ K: tl.constexpr,
+ fill: tl.constexpr,
+ descending: tl.constexpr = core.CONSTEXPR_0):
+ """kernel1."""
+ batch_id = tl.program_id(0).to(tl.int64)
+
+ seqlen = tl.load(seqlen_ptr + batch_id)
+ offs_k = tl.arange(0, K)
+ score_ptrs = score_ptr + batch_id * stride_m + offs_k
+ ids_ptrs = ids_ptr + batch_id * stride_m + offs_k
+
+ # initialize
+ pos = offs_k
+ mask = pos < seqlen
+ scores = tl.load(score_ptrs, mask=mask, other=-1e6)
+ ids = tl.load(ids_ptrs, mask=mask, other=fill)
+
+ pos = 2 * K - 1 - offs_k
+ score_ptrs = score_ptr + batch_id * stride_m + pos
+ ids_ptrs = ids_ptr + batch_id * stride_m + pos
+
+ stage: tl.constexpr = _log2(2 * K)
+ for k in tl.range(K, seqlen, K, num_stages=3):
+ mask = pos < seqlen
+ new_scores = tl.load(score_ptrs, mask=mask, other=-1e6)
+ new_ids = tl.load(ids_ptrs, mask=mask, other=fill)
+
+ merged_scores = _concate(scores, new_scores)
+ merged_ids = _concate(ids, new_ids)
+
+ merged_scores, merged_ids = _bitonic_merge(merged_scores, merged_ids, stage, descending, stage)
+ # merged_scores, merged_ids = argsort(merged_scores, merged_ids, 0, descending)
+
+ scores, _ = _split(merged_scores, K)
+ ids, _ = _split(merged_ids, K)
+ score_ptrs += K
+ ids_ptrs += K
+ pos += K
+
+ out_ptrs = out_ptr + batch_id * K + offs_k
+ tl.store(out_ptrs, ids)
+
+
+def bitonic_topk(scores: torch.Tensor,
+ q_seqlens: torch.Tensor,
+ kv_seqlens: torch.Tensor,
+ k: int,
+ fill: int = -1,
+ descending: bool = True):
+ """Bitnoic topk."""
+ num_tokens = scores.size(0)
+ max_kv_len = scores.size(-1)
+
+ if num_tokens != kv_seqlens.size(0):
+ repeat_kv_seqlens = torch.repeat_interleave(kv_seqlens, q_seqlens, output_size=num_tokens)
+ else:
+ repeat_kv_seqlens = kv_seqlens
+ tmp_scores = torch.empty_like(scores)
+ tmp_ids = torch.empty_like(scores, dtype=torch.int32)
+ grid = (num_tokens, triton.cdiv(max_kv_len, k))
+ _bitonic_topk_kernel0[grid](scores,
+ repeat_kv_seqlens,
+ tmp_scores,
+ tmp_ids,
+ stride_m=scores.stride(0),
+ K=k,
+ fill=fill,
+ descending=1 if descending else 0,
+ num_warps=4)
+
+ out = kv_seqlens.new_empty((num_tokens, k), dtype=torch.int32)
+ _bitonic_topk_kernel1[(num_tokens, )](tmp_scores,
+ tmp_ids,
+ repeat_kv_seqlens,
+ out,
+ stride_m=tmp_scores.stride(0),
+ K=k,
+ fill=fill,
+ descending=1 if descending else 0,
+ num_warps=4)
+ return out
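+
+
+# Minimal usage sketch (illustrative shapes; `k` is assumed to be a power of two).
+# `scores` holds one padded row of indexer logits per query token; the result is
+# a (num_tokens, k) int32 tensor of kv indices, with `fill` in the slots beyond
+# each sequence's kv length:
+#
+#   scores = torch.randn(3, 4096, device='cuda')
+#   q_seqlens = torch.tensor([1, 1, 1], dtype=torch.int32, device='cuda')
+#   kv_seqlens = torch.tensor([1024, 4096, 300], dtype=torch.int32, device='cuda')
+#   topk_ids = bitonic_topk(scores, q_seqlens, kv_seqlens, k=2048, fill=-1)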
diff --git a/lmdeploy/pytorch/kernels/cuda/ds_index.py b/lmdeploy/pytorch/kernels/cuda/ds_index.py
new file mode 100644
index 0000000000..51af0bc8e9
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/ds_index.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import triton
+import triton.language as tl
+
+from .utils import get_device_props
+
+
+@triton.jit
+def _fp8_index_kernel(
+ q_ptr,
+ q_s_ptr,
+ k_cache_ptr,
+ k_s_cache_ptr,
+ cu_seqlen_q_ptr,
+ k_seqlen_ptr,
+ block_offset_ptr,
+ out_ptr,
+ stride_qm: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ stride_qsm: tl.constexpr,
+ stride_qsh: tl.constexpr,
+ stride_kb: tl.constexpr,
+ stride_kn: tl.constexpr,
+ stride_kd: tl.constexpr,
+ stride_ksb: tl.constexpr,
+ stride_ksn: tl.constexpr,
+ stride_boff0,
+ stride_boff1: tl.constexpr,
+ stride_om: tl.constexpr,
+ stride_on: tl.constexpr,
+ max_q_seqlen,
+ BLOCK_H: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ NUM_SPLIT: tl.constexpr,
+):
+ """Fp8 index kernel."""
+ m_id = tl.program_id(0).to(tl.int64)
+ split_id = tl.program_id(1).to(tl.int64)
+
+ assert stride_qd == 1
+ assert stride_kd == 1
+
+ batch_id = m_id // max_q_seqlen
+ q_id = m_id % max_q_seqlen
+ q_start = tl.load(cu_seqlen_q_ptr + batch_id)
+ q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_start
+ if q_id >= q_seqlen:
+ return
+
+ k_seqlen = tl.load(k_seqlen_ptr + batch_id)
+ if k_seqlen <= 0:
+ return
+
+ q_pos = q_start + q_id
+ offs_h = tl.arange(0, BLOCK_H)
+ offs_d = tl.arange(0, BLOCK_D)
+ offs_n = tl.arange(0, BLOCK_N)
+
+ q_ptrs = q_ptr + q_pos * stride_qm + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
+ q_s_ptrs = q_s_ptr + q_pos * stride_qsm + offs_h * stride_qsh
+ q = tl.load(q_ptrs)
+ q_s = tl.load(q_s_ptrs)
+
+ k_ptrs = k_cache_ptr + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
+ k_s_ptrs = k_s_cache_ptr + offs_n * stride_ksn
+ o_ptrs = out_ptr + q_pos * stride_om + offs_n * stride_on + split_id * BLOCK_N * stride_on
+ boff_ptr = block_offset_ptr + batch_id * stride_boff0 + split_id * stride_boff1
+
+ num_blocks = tl.cdiv(k_seqlen, BLOCK_N)
+ for boff_id in tl.range(split_id, num_blocks, NUM_SPLIT, num_stages=3):
+ boff = tl.load(boff_ptr)
+
+ k = tl.load(k_ptrs + boff * stride_kb)
+ k_s = tl.load(k_s_ptrs + boff * stride_ksb)
+
+ logits = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32)
+ logits = tl.dot(q, k, acc=logits)
+ logits = tl.maximum(logits, 0) * q_s[:, None]
+ logits_sum = tl.sum(logits, axis=0) * k_s
+
+ tl.store(o_ptrs, logits_sum, mask=offs_n + boff_id * BLOCK_N < k_seqlen)
+ boff_ptr += NUM_SPLIT * stride_boff1
+ o_ptrs += NUM_SPLIT * BLOCK_N * stride_on
+
+
+def fp8_index(q: torch.Tensor,
+ q_s: torch.Tensor,
+ k_cache: torch.Tensor,
+ k_s_cache: torch.Tensor,
+ cu_seqlen_q: torch.Tensor,
+ k_seqlens: torch.Tensor,
+ block_offset: torch.Tensor,
+ max_q_seqlen: int = None,
+ max_k_seqlen: int = None):
+ """Fp8 index.
+
+ q: (cum_seqlen, num_heads, head_dim)
+ q_s: (cum_seqlen, num_heads)
+ k_cache: (num_blocks, block_size, head_dim)
+ k_s_cache: (num_blocks, block_size)
+    cu_seqlen_q: (batch_size + 1,)
+    k_seqlens: (batch_size,)
+ block_offset: (batch_size, num_blocks)
+ """
+ assert q.dim() == 3
+ assert k_cache.dim() == 3
+ assert q_s.dim() == 2
+ assert k_s_cache.dim() == 2
+ cum_seqlen, num_heads, head_dim = q.shape
+ block_size = k_cache.size(1)
+ batch_size = k_seqlens.numel()
+ is_decoding = batch_size == cum_seqlen
+ if max_k_seqlen is None:
+ max_num_blocks = k_cache.size(0)
+ max_k_seqlen = max_num_blocks * block_size
+
+ # max q seqlen
+ if is_decoding:
+ if max_q_seqlen is None:
+ max_q_seqlen = 1
+ assert max_q_seqlen == 1
+ elif max_q_seqlen is None:
+ max_q_seqlen = cum_seqlen
+
+ assert q.stride(-1) == 1 and k_cache.stride(-1) == 1
+
+ out = q.new_empty((cum_seqlen, max_k_seqlen), dtype=torch.float32)
+
+ num_warps = 4
+ device_idx = q.device.index
+ props = get_device_props(device_idx)
+ num_sm = props['multi_processor_count']
+ # estimated occupancy 12.5%
+ warps_per_sm = props['warps_per_sm'] // 8
+ assert warps_per_sm >= num_warps
+ cta_per_sm = warps_per_sm // num_warps
+ cta_per_device = num_sm * cta_per_sm
+    # NOTE: a tensor mapping each q token to its batch id would avoid padding the grid to max_q_seqlen
+ M = max_q_seqlen * batch_size
+    NUM_SPLIT = max(1, triton.cdiv(cta_per_device, M))
+    # the multi-split path is kept disabled for now
+    NUM_SPLIT = 1
+ grid = (M, NUM_SPLIT)
+
+ _fp8_index_kernel[grid](q,
+ q_s,
+ k_cache,
+ k_s_cache,
+ cu_seqlen_q,
+ k_seqlens,
+ block_offset,
+ out,
+ *q.stride(),
+ *q_s.stride(),
+ *k_cache.stride(),
+ *k_s_cache.stride(),
+ *block_offset.stride(),
+ *out.stride(),
+ max_q_seqlen=max_q_seqlen,
+ BLOCK_H=num_heads,
+ BLOCK_N=block_size,
+ BLOCK_D=head_dim,
+ NUM_SPLIT=NUM_SPLIT,
+ num_warps=num_warps)
+ return out
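+
+
+# Decoding-shaped usage sketch (illustrative sizes; the fp8 `tl.dot` path assumes
+# an sm90+ GPU). Two sequences, one query token each, paged fp8 keys:
+#
+#   num_heads, head_dim, block_size = 64, 128, 64
+#   q = torch.randn(2, num_heads, head_dim, device='cuda').to(torch.float8_e4m3fn)
+#   q_s = torch.ones(2, num_heads, device='cuda')
+#   k_cache = torch.randn(4, block_size, head_dim, device='cuda').to(torch.float8_e4m3fn)
+#   k_s_cache = torch.ones(4, block_size, device='cuda')
+#   cu_seqlen_q = torch.tensor([0, 1, 2], dtype=torch.int32, device='cuda')
+#   k_seqlens = torch.tensor([64, 128], dtype=torch.int32, device='cuda')
+#   block_offset = torch.tensor([[0, 0], [2, 3]], dtype=torch.int32, device='cuda')
+#   scores = fp8_index(q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset)
+#   # scores: (2, num_blocks * block_size) float32; per head relu(q @ k) * q_s,
+#   # summed over heads and rescaled by k_s_cache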
diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
index eeaad8f989..9eb2e76046 100644
--- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Literal
+from typing import Literal, Optional
+import torch
import triton
import triton.language as tl
from torch import Tensor
@@ -266,9 +267,9 @@ def _fill_kv_cache_quant_kernel(
def fill_kv_cache(k_states: Tensor,
- v_states: Tensor,
+ v_states: Optional[Tensor],
k_caches: Tensor,
- v_caches: Tensor,
+ v_caches: Optional[Tensor],
q_start_loc: Tensor,
q_seq_length: Tensor,
kv_seq_length: Tensor,
@@ -285,6 +286,10 @@ def fill_kv_cache(k_states: Tensor,
b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
else:
raise RuntimeError('Unsupported layout.')
+ if v_states is None:
+ v_states = k_states[..., :0]
+ if v_caches is None:
+ v_caches = k_caches[..., :0]
block_offsets = block_offsets.contiguous()
batch_size = block_offsets.size(0)
@@ -386,3 +391,241 @@ def fill_kv_cache(k_states: Tensor,
num_warps=4,
num_stages=3,
)
+
+
+@triton.jit
+def _quant_blocked_fp8(x,
+ fp8_min: tl.constexpr,
+ fp8_max: tl.constexpr,
+ dtype: tl.constexpr,
+ GROUP_SIZE: tl.constexpr = 128):
+ x = x.to(tl.float32)
+ M: tl.constexpr = x.shape[0]
+ N: tl.constexpr = x.shape[1]
+ rfp8_max: tl.constexpr = 1 / fp8_max
+ x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE)
+ scale = tl.maximum(tl.max(tl.abs(x), axis=2, keep_dims=True), 1e-6) * rfp8_max
+ out = x / scale
+
+ out = tl.clamp(out, fp8_min, fp8_max)
+ out = out.to(dtype)
+ out = out.reshape(M, N)
+ scale = scale.reshape(M, N // GROUP_SIZE)
+ return out, scale
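+
+# A rough PyTorch equivalent of the per-group quantization above, for reference
+# only (assumes the last dimension is a multiple of `group_size`):
+#
+#   def quant_blocked_fp8_ref(x, group_size=128, dtype=torch.float8_e4m3fn):
+#       fmax = torch.finfo(dtype).max
+#       xg = x.float().unflatten(-1, (-1, group_size))
+#       scale = xg.abs().amax(-1, keepdim=True).clamp_min(1e-6) / fmax
+#       quant = (xg / scale).clamp(-fmax, fmax).to(dtype)
+#       return quant.flatten(-2), scale.squeeze(-1)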
+
+
+@triton.jit
+def _fill_kv_cache_blocked_fp8_kernel(
+ KStates,
+ VStates,
+ KCaches,
+ VCaches,
+ KSCaches,
+ VSCaches,
+ cu_seqlen_q_ptr,
+ KVSeqLens,
+ BlockOffsets,
+ fp8_min: tl.constexpr,
+ fp8_max: tl.constexpr,
+ is_decoding: tl.constexpr,
+ head_dim: tl.constexpr,
+ head_dim_v: tl.constexpr,
+ stride_kss,
+ stride_ksh,
+ stride_ksd,
+ stride_vss,
+ stride_vsh,
+ stride_vsd,
+ stride_kcn: tl.constexpr,
+ stride_kcb: tl.constexpr,
+ stride_kch: tl.constexpr,
+ stride_kcd: tl.constexpr,
+ stride_vcn: tl.constexpr,
+ stride_vcb: tl.constexpr,
+ stride_vch: tl.constexpr,
+ stride_vcd: tl.constexpr,
+ stride_kscn: tl.constexpr,
+ stride_kscb: tl.constexpr,
+ stride_ksch: tl.constexpr,
+ stride_kscd: tl.constexpr,
+ stride_vscn: tl.constexpr,
+ stride_vscb: tl.constexpr,
+ stride_vsch: tl.constexpr,
+ stride_vscd: tl.constexpr,
+ stride_boff,
+ GROUP_SIZE: tl.constexpr,
+ BLOCK: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ BLOCK_DV: tl.constexpr,
+):
+ """Fill kv cache kernel."""
+ batch_id = tl.program_id(2)
+ head_id = tl.program_id(0)
+ block_id = tl.program_id(1)
+
+ q_startloc = tl.load(cu_seqlen_q_ptr + batch_id)
+ q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_startloc
+ kv_seqlen = tl.load(KVSeqLens + batch_id)
+ history_seqlen = kv_seqlen - q_seqlen
+
+ kv_block_id = history_seqlen // BLOCK + block_id
+
+ if kv_seqlen <= 0:
+ return
+
+ if kv_block_id * BLOCK >= kv_seqlen:
+ return
+
+ if is_decoding:
+ page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32)
+ kv_mask = tl.full((1, ), 1, dtype=tl.int1)
+ q_offs = tl.full((1, ), q_startloc, dtype=tl.int32)
+ else:
+ page_offs = tl.arange(0, BLOCK)
+ kv_offs = kv_block_id * BLOCK + page_offs
+ kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen)
+ token_off = q_startloc + kv_block_id * BLOCK - history_seqlen
+ q_offs = token_off + page_offs
+
+ block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)
+
+ d_off = tl.arange(0, BLOCK_D)
+ mask_ks = kv_mask[:, None]
+ mask_kc = mask_ks & (d_off[None, :] < head_dim)
+ d_off = d_off % head_dim
+
+ BLOCK_DS: tl.constexpr = (BLOCK_D + GROUP_SIZE - 1) // GROUP_SIZE
+ ds_off = tl.arange(0, BLOCK_DS)
+
+ ks_ptr = KStates + head_id * stride_ksh
+ ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd
+ kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch
+ kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd
+ ksc_ptr = KSCaches + block_off * stride_kscn + head_id * stride_ksch
+ ksc_ptrs = ksc_ptr + page_offs[:, None] * stride_kscb + ds_off[None, :] * stride_kscd
+
+ if BLOCK_DV > 0:
+ dv_off = tl.arange(0, BLOCK_DV)
+ mask_vs = kv_mask[:, None]
+ mask_vc = mask_vs & (dv_off[None, :] < head_dim_v)
+
+ BLOCK_DVS: tl.constexpr = (BLOCK_DV + GROUP_SIZE - 1) // GROUP_SIZE
+ dvs_off = tl.arange(0, BLOCK_DVS)
+
+ dv_off = dv_off % head_dim_v
+ vs_ptr = VStates + head_id * stride_vsh
+ vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd
+ vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch
+ vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd
+ vsc_ptr = VSCaches + block_off * stride_vscn + head_id * stride_vsch
+ vsc_ptrs = vsc_ptr + page_offs[:, None] * stride_vscb + dvs_off[None, :] * stride_vscd
+
+ k = tl.load(ks_ptrs, mask=mask_ks)
+ if BLOCK_DV > 0:
+ v = tl.load(vs_ptrs, mask=mask_vs)
+ kc, kcs = _quant_blocked_fp8(k, fp8_min, fp8_max, KCaches.dtype.element_ty, GROUP_SIZE)
+ tl.store(kc_ptrs, kc, mask=mask_kc)
+ tl.store(ksc_ptrs, kcs, mask=kv_mask[:, None] & (ds_off[None, :] < _div_up(head_dim, GROUP_SIZE)))
+ if BLOCK_DV > 0:
+ vc, vcs = _quant_blocked_fp8(v, fp8_min, fp8_max, VCaches.dtype.element_ty, GROUP_SIZE)
+ tl.store(vc_ptrs, vc, mask=mask_vc)
+        tl.store(vsc_ptrs, vcs, mask=kv_mask[:, None] & (dvs_off[None, :] < _div_up(head_dim_v, GROUP_SIZE)))
+
+
+def fill_kv_cache_blocked_fp8(k_states: Tensor,
+ v_states: Optional[Tensor],
+ k_caches: Tensor,
+ v_caches: Optional[Tensor],
+ ks_caches: Tensor,
+ vs_caches: Optional[Tensor],
+ cu_seqlen_q: Tensor,
+ kv_seqlens: Tensor,
+ max_q_seqlen: int,
+ block_offsets: Tensor,
+ group_size: int = 128,
+ kv_layout: str = 'bshd'):
+ """Fill key/value state to cache for paged attention with fp8
+ quantization."""
+
+ if kv_layout == 'bshd':
+ b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
+ elif kv_layout == 'bhsd':
+ b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
+ else:
+ raise RuntimeError('Unsupported layout.')
+
+ if v_states is None:
+ v_states = k_states[..., :0]
+ if v_caches is None:
+ v_caches = k_caches[..., :0]
+ if vs_caches is None:
+ vs_caches = ks_caches[..., :0]
+
+ block_offsets = block_offsets.contiguous()
+ batch_size = block_offsets.size(0)
+ block_size = k_caches.size(s_dim)
+ num_heads = k_caches.size(h_dim)
+ head_dim = k_caches.size(d_dim)
+ head_dim_v = v_states.size(-1)
+ if max_q_seqlen == 1:
+ max_num_blocks = 1
+ else:
+ max_num_blocks = triton.cdiv(max_q_seqlen, block_size) + 1
+
+ BLOCK = block_size
+ BLOCK_D = triton.next_power_of_2(head_dim)
+ BLOCK_DV = triton.next_power_of_2(head_dim_v)
+ if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim:
+ BLOCK_DV = 0
+
+ dtype = k_caches.dtype
+ finfo = torch.finfo(dtype)
+ fmin = finfo.min
+ fmax = finfo.max
+
+ grid = (num_heads, max_num_blocks, batch_size)
+ is_decoding = max_q_seqlen == 1
+ _fill_kv_cache_blocked_fp8_kernel[grid](
+ k_states,
+ v_states,
+ k_caches,
+ v_caches,
+ ks_caches,
+ vs_caches,
+ cu_seqlen_q,
+ kv_seqlens,
+ block_offsets,
+ fp8_min=fmin,
+ fp8_max=fmax,
+ is_decoding=is_decoding,
+ head_dim=head_dim,
+ head_dim_v=head_dim_v,
+ stride_kss=k_states.stride(-3),
+ stride_ksh=k_states.stride(-2),
+ stride_ksd=k_states.stride(-1),
+ stride_vss=v_states.stride(-3),
+ stride_vsh=v_states.stride(-2),
+ stride_vsd=v_states.stride(-1),
+ stride_kcn=k_caches.stride(b_dim),
+ stride_kcb=k_caches.stride(s_dim),
+ stride_kch=k_caches.stride(h_dim),
+ stride_kcd=k_caches.stride(d_dim),
+ stride_vcn=v_caches.stride(b_dim),
+ stride_vcb=v_caches.stride(s_dim),
+ stride_vch=v_caches.stride(h_dim),
+ stride_vcd=v_caches.stride(d_dim),
+ stride_kscn=ks_caches.stride(b_dim),
+ stride_kscb=ks_caches.stride(s_dim),
+ stride_ksch=ks_caches.stride(h_dim),
+ stride_kscd=ks_caches.stride(d_dim),
+ stride_vscn=vs_caches.stride(b_dim),
+ stride_vscb=vs_caches.stride(s_dim),
+ stride_vsch=vs_caches.stride(h_dim),
+ stride_vscd=vs_caches.stride(d_dim),
+ stride_boff=block_offsets.stride(0),
+ GROUP_SIZE=group_size,
+ BLOCK=BLOCK,
+ BLOCK_D=BLOCK_D,
+ BLOCK_DV=BLOCK_DV,
+ num_warps=4,
+ )
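+
+
+# Call sketch for the `bshd` layout (illustrative shapes):
+#   k_states, v_states:   (num_tokens, num_heads, head_dim)                           bf16/fp16
+#   k_caches, v_caches:   (num_blocks, block_size, num_heads, head_dim)               fp8
+#   ks_caches, vs_caches: (num_blocks, block_size, num_heads, head_dim // group_size) fp32
+#   cu_seqlen_q: (batch_size + 1,) int32, kv_seqlens: (batch_size,) int32
+#
+#   fill_kv_cache_blocked_fp8(k_states, v_states, k_caches, v_caches, ks_caches,
+#                             vs_caches, cu_seqlen_q, kv_seqlens, max_q_seqlen,
+#                             block_offsets, group_size=128)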
diff --git a/lmdeploy/pytorch/kernels/cuda/flash_mla.py b/lmdeploy/pytorch/kernels/cuda/flash_mla.py
deleted file mode 100644
index 1a3209edeb..0000000000
--- a/lmdeploy/pytorch/kernels/cuda/flash_mla.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Tuple
-
-import torch
-
-
-def flash_mla_fwd(
- q: torch.Tensor,
- k_cache: torch.Tensor,
- block_table: torch.Tensor,
- cache_seqlens: torch.Tensor,
- head_dim_v: int,
- tile_scheduler_metadata: torch.Tensor,
- num_splits: torch.Tensor,
- softmax_scale: Optional[float] = None,
- causal: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Arguments:
- q: (batch_size, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head_dim of v.
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
- num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
- softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
-
- Return:
- out: (batch_size, num_heads_q, head_dim_v).
- """
- import flash_mla
- out, _ = flash_mla.flash_mla_with_kvcache(
- q,
- k_cache,
- block_table,
- cache_seqlens,
- head_dim_v,
- tile_scheduler_metadata,
- num_splits,
- softmax_scale,
- causal,
- )
- return out.squeeze(1)
diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
index bd82e347ed..facb9a69f4 100644
--- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
+++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py
@@ -333,3 +333,174 @@ def flatten_kv_cache(k_caches: Tensor,
)
return k_states, v_states
+
+
+@triton.jit
+def dequant_fp8(x, scale, GROUP_SIZE: tl.constexpr):
+ """Dequant fp8."""
+ M: tl.constexpr = x.shape[0]
+ N: tl.constexpr = x.shape[1]
+ x = x.to(scale.dtype)
+ x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE)
+ scale = scale.reshape(M, N // GROUP_SIZE, 1)
+ x = x * scale
+ x = x.reshape(M, N)
+ return x
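+
+# Inverse of the per-group fp8 quantization used when filling the cache: each
+# group of GROUP_SIZE values is scaled back by its stored scale, roughly
+#   x.float().unflatten(-1, (-1, GROUP_SIZE)) * scale.unsqueeze(-1)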
+
+
+@triton.jit
+def flatten_kv_cache_mla_fp8_kernel(
+ kc_nope_ptr,
+ kc_scale_ptr,
+ kc_pe_ptr,
+ ko_ptr,
+ start_loc_ptr,
+ seqlens_ptr,
+ block_offsets_ptr,
+ stride_kcb: tl.constexpr,
+ stride_kcs: tl.constexpr,
+ stride_kch: tl.constexpr,
+ stride_kcd: tl.constexpr,
+ stride_kcsb: tl.constexpr,
+ stride_kcss: tl.constexpr,
+ stride_kcsh: tl.constexpr,
+ stride_kcsd: tl.constexpr,
+ stride_kcpb: tl.constexpr,
+ stride_kcps: tl.constexpr,
+ stride_kcph: tl.constexpr,
+ stride_kcpd: tl.constexpr,
+ stride_koh,
+ stride_kos: tl.constexpr,
+ stride_kod: tl.constexpr,
+ stride_boff,
+ OUT_SIZE,
+ BLOCK_BS: tl.constexpr,
+ BLOCK_NOPE: tl.constexpr,
+ BLOCK_PE: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ """Mla fp8 flatten kv cache kernel."""
+ page_id = tl.program_id(0)
+ batch_id = tl.program_id(1)
+ head_id = tl.program_id(2)
+ num_batches = tl.num_programs(1)
+
+ seqlen = tl.load(seqlens_ptr + batch_id)
+ start_loc = tl.load(start_loc_ptr + batch_id)
+    # pad the last sequence up to OUT_SIZE so trailing output slots are initialized
+    # and attention does not read garbage (NaN)
+ if batch_id == num_batches - 1:
+ seqlen = OUT_SIZE - start_loc
+ if page_id * BLOCK_BS >= seqlen:
+ return
+
+ b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id)
+
+ BLOCK_SCALE: tl.constexpr = BLOCK_NOPE // GROUP_SIZE
+ offs_bs = tl.arange(0, BLOCK_BS)
+ offs_dnope = tl.arange(0, BLOCK_NOPE)
+ offs_scale = tl.arange(0, BLOCK_SCALE)
+ offs_dpe = tl.arange(0, BLOCK_PE)
+ offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS)
+ mask_bs = offs_obs < seqlen
+
+ offs_kc = b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch
+ kc_nope_ptrs = (kc_nope_ptr + offs_kc + offs_dnope[None, :] * stride_kcd)
+
+ offs_kc_scale = b_off * stride_kcsb + offs_bs[:, None] * stride_kcss + head_id * stride_kcsh
+ kc_scale_ptrs = (kc_scale_ptr + offs_kc_scale + offs_scale[None, :] * stride_kcsd)
+
+ offs_kc_pe = b_off * stride_kcpb + offs_bs[:, None] * stride_kcps + head_id * stride_kcph
+ kc_pe_ptrs = (kc_pe_ptr + offs_kc_pe + offs_dpe[None, :] * stride_kcpd)
+
+ offs_ko = head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos
+ ko_nope_ptrs = (ko_ptr + offs_ko + offs_dnope[None, :] * stride_kod)
+ ko_pe_ptrs = (ko_ptr + offs_ko + (BLOCK_NOPE + offs_dpe[None, :]) * stride_kod)
+
+ # nope
+ kc_nope = tl.load(kc_nope_ptrs)
+ kc_scale = tl.load(kc_scale_ptrs)
+ ko_nope = dequant_fp8(kc_nope, kc_scale, GROUP_SIZE)
+ ko_nope = ko_nope.to(ko_ptr.dtype.element_ty)
+ tl.store(ko_nope_ptrs, ko_nope, mask=mask_bs[:, None])
+
+ # pe
+ kc_pe = tl.load(kc_pe_ptrs)
+ tl.store(ko_pe_ptrs, kc_pe, mask=mask_bs[:, None])
+
+
+def flatten_kv_cache_mla_fp8(k_caches: Tensor,
+ seqlens: Tensor,
+ block_offsets: Tensor,
+ start_loc: Tensor = None,
+ out_size: int = None,
+ out_dtype: torch.dtype = None,
+ flatten_kv_layout: str = 'hsd'):
+ """This kernel is designed to support mla fp8."""
+ assert k_caches.dim() == 4
+
+ b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
+
+ if out_dtype is None:
+ out_dtype = torch.bfloat16
+
+ if out_size is None or out_size <= 0:
+ out_size = k_caches.size(b_dim) * k_caches.size(s_dim)
+
+    # TODO: avoid hard-coding the MLA fp8 cache layout
+    # (512 fp8 nope bytes + 16 scale bytes + the rope part viewed as out_dtype)
+ k_caches_nope = k_caches[..., :512]
+ k_caches_scale = k_caches[..., 512:512 + 16].view(torch.float32)
+ k_caches_pe = k_caches[..., 512 + 16:].view(out_dtype)
+
+ if start_loc is None:
+ start_loc = seqlens.cumsum(0) - seqlens
+
+ batch_size, num_blocks = block_offsets.size()
+ num_heads = k_caches.size(h_dim)
+ k_head_dim = 576
+ BLOCK_NOPE = 512
+ BLOCK_PE = 64
+ BLOCK_BS = k_caches.size(s_dim)
+ if flatten_kv_layout == 'hsd':
+ k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype)
+ stride_koh = k_states.stride(0)
+ stride_kos = k_states.stride(1)
+ elif flatten_kv_layout == 'shd':
+ k_states = k_caches.new_empty(out_size, num_heads, k_head_dim, dtype=out_dtype)
+ stride_koh = k_states.stride(1)
+ stride_kos = k_states.stride(0)
+ else:
+ raise RuntimeError(f'Unsupported layout: {flatten_kv_layout}.')
+
+ grid = (num_blocks, batch_size, num_heads)
+ flatten_kv_cache_mla_fp8_kernel[grid](
+ k_caches_nope,
+ k_caches_scale,
+ k_caches_pe,
+ k_states,
+ start_loc,
+ seqlens,
+ block_offsets,
+ stride_kcb=k_caches_nope.stride(b_dim),
+ stride_kcs=k_caches_nope.stride(s_dim),
+ stride_kch=k_caches_nope.stride(h_dim),
+ stride_kcd=k_caches_nope.stride(d_dim),
+ stride_kcsb=k_caches_scale.stride(b_dim),
+ stride_kcss=k_caches_scale.stride(s_dim),
+ stride_kcsh=k_caches_scale.stride(h_dim),
+ stride_kcsd=k_caches_scale.stride(d_dim),
+ stride_kcpb=k_caches_pe.stride(b_dim),
+ stride_kcps=k_caches_pe.stride(s_dim),
+ stride_kcph=k_caches_pe.stride(h_dim),
+ stride_kcpd=k_caches_pe.stride(d_dim),
+ stride_koh=stride_koh,
+ stride_kos=stride_kos,
+ stride_kod=k_states.stride(2),
+ stride_boff=block_offsets.stride(0),
+ OUT_SIZE=out_size,
+ BLOCK_BS=BLOCK_BS,
+ BLOCK_NOPE=BLOCK_NOPE,
+ BLOCK_PE=BLOCK_PE,
+ GROUP_SIZE=128,
+ )
+
+ return k_states
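+
+
+# Call sketch, assuming the hard-coded MLA fp8 cache layout above (per token:
+# 512 fp8 nope bytes, 16 bytes of float32 scales, then the rope part stored in
+# `out_dtype`, i.e. a 656-byte last dimension for bf16 rope):
+#
+#   # k_caches: (num_blocks, block_size, 1, 656), fp8 view of the paged cache
+#   k_states = flatten_kv_cache_mla_fp8(k_caches, seqlens, block_offsets,
+#                                       start_loc=start_loc, out_size=out_size)
+#   # k_states: (num_heads, out_size, 576) bf16, dequantized nope + raw rope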
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 13e35fd1ae..1a636ce380 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -9,7 +9,7 @@
# from torch import distributed as dist
import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.backends import get_backend
-from lmdeploy.pytorch.config import DLLMConfig, ModelConfig
+from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
if TYPE_CHECKING:
@@ -144,6 +144,7 @@ class ModelInputs:
model_metas: List[Dict[str, Any]] = None
dp_meta: 'DPMeta' = None
enable_microbatch: bool = False
+ is_dummy: bool = False
def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None):
"""Update input ids."""
@@ -239,6 +240,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int):
end = min(max_seq_len, start + split_size)
max_q_seqlen = end - start
if isinstance(max_q_seqlen, torch.Tensor):
max_q_seqlen = max_q_seqlen.item()
max_kv_seqlen += max_q_seqlen
@@ -299,6 +301,7 @@ class StepContext:
"""
input_ids: torch.LongTensor
model_config: ModelConfig
+ cache_config: CacheConfig
block_offsets: torch.IntTensor
position_ids: torch.LongTensor
attention_mask: torch.LongTensor
@@ -329,6 +332,7 @@ def new(
cls,
inputs: ModelInputs,
model_config: ModelConfig,
+ cache_config: CacheConfig,
kv_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
):
@@ -365,10 +369,13 @@ def new(
# seq_len + history_length
kv_seqlens = q_seqlens + history_seqlens
kv_seqlens -= inputs.num_ignored_history
+ if inputs.is_dummy:
+ kv_seqlens = torch.zeros_like(kv_seqlens)
ret = StepContext(
input_ids=inputs.input_ids,
model_config=model_config,
+ cache_config=cache_config,
block_offsets=inputs.block_offsets,
position_ids=position_ids,
input_embeddings=input_embeddings,
@@ -453,6 +460,7 @@ def build_context(
self,
inputs: ModelInputs,
model_config: ModelConfig,
+ cache_config: CacheConfig,
kv_caches: List = None,
kv_quant_policy: Literal[0, 4, 8] = 0,
):
@@ -460,6 +468,7 @@ def build_context(
return StepContext.new(
inputs,
model_config,
+ cache_config,
kv_caches,
kv_quant_policy,
)
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index a10e5da520..83cf98ea0c 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -1285,6 +1285,14 @@ def __skip_nextn(name, nextn_keys):
return True
return False
+        def __skip_layers():
+            """Skip layers whose index exceeds ``num_hidden_layers``; the layer
+            count may be reduced to debug the model with fewer GPUs."""
+            import re
+            matches = re.findall(r'\.layers\.(\d+)\.', name)
+            if not matches:
+                return False
+            layer_id = int(matches[0])
+            return layer_id >= self.config.num_hidden_layers
+
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
('.gate_up_proj', '.gate_proj', 0),
@@ -1334,6 +1342,10 @@ def __skip_nextn(name, nextn_keys):
# skip nextn
if __skip_nextn(name, nextn_keys):
continue
+
+ if __skip_layers():
+ continue
+
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
continue
diff --git a/lmdeploy/pytorch/models/deepseek_v32.py b/lmdeploy/pytorch/models/deepseek_v32.py
new file mode 100644
index 0000000000..75b776249f
--- /dev/null
+++ b/lmdeploy/pytorch/models/deepseek_v32.py
@@ -0,0 +1,388 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Sequence, Tuple
+
+import torch
+from torch import nn
+
+from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank
+from lmdeploy.pytorch.model_inputs import StepContextManager
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, build_rotary_embedding,
+ build_rotary_params)
+from lmdeploy.pytorch.nn.eplb import EPLBManager
+from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_rowwise_linear
+from lmdeploy.pytorch.nn.nsa import IndexerTopKFP8
+
+from .deepseek_v2 import (DeepseekV2Attention, DeepseekV2BMM, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
+ DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, yarn_get_mscale)
+
+
+def rotate_activation(x: torch.Tensor) -> torch.Tensor:
+ assert x.dtype == torch.bfloat16
+ from fast_hadamard_transform import hadamard_transform
+ hidden_size = x.size(-1)
+ return hadamard_transform(x, scale=hidden_size**-0.5)
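+
+
+# `hadamard_transform(x, scale=d**-0.5)` applies a scaled Walsh-Hadamard rotation
+# along the last dimension. A slow reference for sanity checking, assuming the
+# hidden size is a power of two:
+#
+#   import math
+#   def hadamard_reference(x: torch.Tensor) -> torch.Tensor:
+#       d = x.size(-1)
+#       base = torch.tensor([[1.0, 1.0], [1.0, -1.0]], device=x.device)
+#       h = base
+#       for _ in range(int(math.log2(d)) - 1):
+#           h = torch.kron(h, base)
+#       return (x.float() @ h * d**-0.5).to(x.dtype)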
+
+
+class Indexer(nn.Module):
+
+ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ try:
+ import fast_hadamard_transform # noqa: F401
+ except ImportError:
+ raise ImportError('Please install fast_hadamard_transform package.')
+ quant_config = getattr(config, 'quantization_config', None)
+ # self.dim: int = 2048
+ self.dim: int = config.hidden_size
+ self.n_heads: int = config.index_n_heads
+ self.n_local_heads = config.index_n_heads
+ self.head_dim: int = config.index_head_dim
+ self.rope_head_dim: int = config.qk_rope_head_dim
+ self.index_topk: int = config.index_topk
+ self.q_lora_rank: int = config.q_lora_rank
+ self.wq_b = build_colwise_linear(self.q_lora_rank,
+ self.n_heads * self.head_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=False,
+ quant_config=quant_config)
+ self.wk = build_colwise_linear(self.dim,
+ self.head_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=False,
+ quant_config=quant_config)
+ self.k_norm = nn.LayerNorm(self.head_dim, dtype=dtype, device=device)
+ self.weights_proj = build_colwise_linear(self.dim,
+ self.n_heads,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=False)
+ self.softmax_scale = self.head_dim**-0.5
+ self.scale_fmt = quant_config['scale_fmt']
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+ self.indexer_topk = IndexerTopKFP8(self.index_topk, self.softmax_scale, block_size=128, fill=-1)
+
+ def forward(self,
+ x: torch.Tensor,
+ qr: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ index_cache: Tuple[torch.Tensor, torch.Tensor],
+ attn_metadata: Any = None):
+ q = self.wq_b(qr)
+ q = q.unflatten(-1, (-1, self.head_dim))
+ q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
+ k = self.wk(x)
+ k = self.k_norm(k)
+ k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
+
+ # apply rotary embedding
+ cos, sin = freqs_cis
+ q_pe, k_pe = self.apply_rotary_pos_emb(
+ q_pe,
+ k_pe[..., None, :],
+ cos,
+ sin,
+ inplace=False,
+ )
+ k_pe = k_pe[0, :]
+ k_nope = k_nope[0, :, None]
+ q = torch.cat([q_pe, q_nope], dim=-1)
+ k = torch.cat([k_pe, k_nope], dim=-1)
+ q = rotate_activation(q)
+ k = rotate_activation(k)
+
+ weights = self.weights_proj(x) * self.n_heads**-0.5
+
+ return self.indexer_topk(q[0], k[:, 0], weights[0], index_cache[0], index_cache[1], attn_metadata=attn_metadata)
+
+
+class DeepseekV32Attention(DeepseekV2Attention):
+
+ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
+ nn.Module.__init__(self)
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.q_lora_rank = config.q_lora_rank
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.qk_rope_head_dim = config.qk_rope_head_dim
+ self.kv_lora_rank = config.kv_lora_rank
+ self.v_head_dim = config.v_head_dim
+ self.qk_nope_head_dim = config.qk_nope_head_dim
+ self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+ num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
+ num_key_value_heads = getattr(config, 'num_key_value_heads', 1)
+ use_flash_mla = getattr(config, 'use_flash_mla', False)
+
+ if self.q_lora_rank is None:
+ self.q_proj = build_colwise_linear(
+ self.hidden_size,
+ self.num_heads * self.q_head_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ quant_config=quantization_config,
+ dp_disable_tp=True,
+ )
+ else:
+ self.q_a_proj = build_colwise_linear(
+ self.hidden_size,
+ config.q_lora_rank,
+ bias=config.attention_bias,
+ dtype=dtype,
+ device=device,
+ is_tp=False,
+ quant_config=quantization_config,
+ )
+ self.q_a_layernorm = RMSNorm(config.q_lora_rank,
+ 1e-6,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+ self.q_b_proj = build_colwise_linear(
+ config.q_lora_rank,
+ self.num_heads * self.q_head_dim,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ quant_config=quantization_config,
+ dp_disable_tp=True,
+ )
+
+ self.kv_a_proj_with_mqa = build_colwise_linear(
+ self.hidden_size,
+ config.kv_lora_rank + config.qk_rope_head_dim,
+ bias=config.attention_bias,
+ dtype=dtype,
+ device=device,
+ is_tp=False,
+ quant_config=quantization_config,
+ )
+ self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
+ 1e-6,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+ self.kc = DeepseekV2BMM(self.num_heads,
+ config.qk_nope_head_dim,
+ config.kv_lora_rank,
+ dtype=dtype,
+ device=device)
+
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+
+ self.softmax_scale = self.q_head_dim**(-0.5)
+
+ if config.rope_scaling is not None:
+ mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0)
+ scaling_factor = config.rope_scaling['factor']
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.softmax_scale = self.softmax_scale * mscale * mscale
+
+ self.attn_fwd = Attention(self.num_heads,
+ config.kv_lora_rank + self.qk_rope_head_dim,
+ scale=self.softmax_scale,
+ num_kv_heads=num_key_value_heads,
+ v_head_size=config.kv_lora_rank,
+ num_replicate_kv_heads=num_replicate_kv_heads,
+ use_flash_mla=use_flash_mla)
+
+ self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device)
+ self.o_proj = build_o_proj(
+ self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=config.attention_bias,
+ dtype=dtype,
+ device=device,
+ is_tp=True,
+ quant_config=quantization_config,
+ )
+
+ self.indexer = Indexer(config, dtype=dtype, device=device)
+
+ def _q_proj(self, hidden_states, num_heads: int, nope_size: int, pe_size: int):
+ """Q proj."""
+ q_len = hidden_states.size(1)
+
+ query_states = hidden_states.new_empty(q_len, num_heads, nope_size + pe_size)
+
+ if self.q_lora_rank is None:
+ qr = hidden_states
+ q = self.q_proj(hidden_states)
+ else:
+ qr = self.q_a_layernorm(self.q_a_proj(hidden_states))
+ q = self.q_b_proj(qr)
+ q = q.view(q_len, num_heads, self.q_head_dim)
+ # q_pe: (q_len, num_heads, qk_rope_head_dim)
+ q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ # q_nope: (q_len, num_heads, kv_lora_rank)
+ q_nope_out = query_states[..., :nope_size]
+ self.kc(q_nope, q_nope_out)
+ return query_states, q_pe, qr
+
+ def _kv_proj(self, hidden_states, nope_size: int):
+ """Kv proj."""
+ # (q_len, 1, nope_size + pe_size)
+ key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None])
+ # (q_len, 1, pe_size)
+ k_pe = key_states[..., nope_size:]
+ # kv_a_layernorm
+ value_states = key_states[..., :nope_size]
+ value_states = self.kv_a_layernorm(value_states)
+ key_states[..., :nope_size] = value_states
+ return key_states, value_states, k_pe
+
+ def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int):
+ """Qkv proj."""
+ nope_size = self.kv_lora_rank
+ pe_size = self.qk_rope_head_dim
+ query_states, q_pe, qr = self._q_proj(hidden_states, num_heads, nope_size, pe_size)
+ key_states, value_states, k_pe = self._kv_proj(hidden_states, nope_size)
+
+ return query_states, key_states, value_states, q_pe, k_pe, qr
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Sequence[torch.Tensor] = None,
+ attn_metadata: Any = None,
+ ):
+ """Rewrite of LlamaAttention.forward."""
+ dist_ctx = get_dist_manager().current_context()
+ if dist_ctx.dp > 1:
+ num_heads = self.num_heads
+ else:
+ world_size = dist_ctx.world_size
+ num_heads = self.num_heads // world_size
+ nope_size = self.kv_lora_rank
+ q_len = hidden_states.size(1)
+
+ # qkv_proj
+ query_states, key_states, value_states, q_pe, k_pe, qr = self._qkv_proj(hidden_states, num_heads=num_heads)
+
+ cos, sin = rotary_pos_emb
+ q_pe, k_pe = self.apply_rotary_pos_emb(
+ q_pe,
+ k_pe,
+ cos,
+ sin,
+ inplace=False,
+ )
+ query_states[..., nope_size:] = q_pe
+ key_states[..., nope_size:] = k_pe
+
+ topk_indices = self.indexer(hidden_states, qr, rotary_pos_emb, past_key_value[-2:], attn_metadata=attn_metadata)
+
+ attn_output = self.attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ past_key_value[0],
+ past_key_value[0][..., :nope_size],
+ attn_metadata,
+ k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
+ v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
+ nsa_indices=topk_indices,
+ )
+ attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim)
+
+ self.vc(attn_output, attn_bmm_out)
+ attn_output = attn_bmm_out.flatten(-2, -1)[None]
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output
+
+
+class DeepseekV32DecoderLayer(DeepseekV2DecoderLayer):
+
+ def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None):
+ nn.Module.__init__(self)
+ self.layer_idx = layer_idx
+ quantization_config = None
+
+ # build attention layer
+ if getattr(config, 'use_mla', True):
+ self.self_attn = DeepseekV32Attention(config, dtype=dtype, device=device)
+ else:
+ # deepseek-vl2-tiny uses MHA LlamaAttention structure
+ from lmdeploy.pytorch.models.llama import LlamaAttention
+ self.self_attn = LlamaAttention(config, dtype=dtype, device=device)
+
+ # mlp
+ self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if
+ (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
+ and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device))
+
+ # build input layer norm
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ # build attention layer norm
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
+
+
+class DeepseekV32Model(DeepseekV2Model):
+
+ def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None):
+ nn.Module.__init__(self)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.embed_tokens = nn.Embedding(config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ dtype=dtype,
+ device=device)
+ if get_dist_manager().current_context().dist_config.enable_eplb:
+ ep_size_, _ = get_ep_world_rank()
+ EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers)
+ self.layers = nn.ModuleList([
+ DeepseekV32DecoderLayer(config, layer_idx, dtype=dtype, device=device)
+ for layer_idx in range(config.num_hidden_layers)
+ ])
+
+ # build norm
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=dtype, device=device)
+
+ emb_type = RopeType.LinearScaling
+ rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size //
+ config.num_attention_heads)
+ rope_max_pos_emb = config.max_position_embeddings
+ rope_base = config.rope_theta
+
+ rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base)
+ update_params = build_rotary_params(config)
+ rope_params.update(update_params)
+ self.rotary_emb = build_rotary_embedding(**rope_params)
+
+
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+
+ def __init__(self,
+ config: Any,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ nn.Module.__init__(self)
+ self.config = config
+ self.quantization_config = getattr(config, 'quantization_config', None)
+ self.dtype = dtype
+ self.ctx_mgr = ctx_mgr
+ self.model = DeepseekV32Model(config, dtype=dtype, device=device)
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+ self._load_buffers = dict()
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index 498e2c6554..b90bfe3ba2 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -100,6 +100,9 @@
# deepseek-v3
MODULE_MAP.update({'DeepseekV3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'})
+# deepseek-v32
+MODULE_MAP.update({'DeepseekV32ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v32.DeepseekV32ForCausalLM'})
+
# deepseek-vl2
MODULE_MAP.update({'DeepseekVLV2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_vl2.DeepseekVLV2ForCausalLM'})
diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py
index 065aef97d1..6d0772c175 100644
--- a/lmdeploy/pytorch/models/utils/cudagraph.py
+++ b/lmdeploy/pytorch/models/utils/cudagraph.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
import torch
from torch import Tensor
@@ -34,6 +34,9 @@ class CudaGraphMeta:
input_buffers: BuffType = None
output_buffers: BuffType = None
vocab_size: int = 1
+ use_mla_fp8_cache: bool = False
+ use_flash_mla: bool = False
+ mla_index_topk: Optional[int] = None
class CudaGraphMixin:
@@ -68,8 +71,16 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) ->
import flash_mla
# create buffers for flash mla
+ num_attention_heads = self.config.num_attention_heads
+ index_topk = graph_meta.mla_index_topk
+ num_heads_q = None if index_topk is None else num_attention_heads
input_buffers['tile_scheduler_metadata'], input_buffers['num_splits'] = flash_mla.get_mla_metadata(
- torch.ones(max_batches, dtype=torch.int32, device=device), self.config.num_attention_heads, 1)
+ torch.ones(max_batches, dtype=torch.int32, device=device),
+ num_attention_heads,
+ num_heads_k=1,
+ num_heads_q=num_heads_q,
+ is_fp8_kvcache=graph_meta.use_mla_fp8_cache,
+ topk=index_topk)
# flash_mla requires block_offsets and kv_lens int32
input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int32, device=device)
@@ -82,6 +93,10 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) ->
# create buffer for cross_attn_metadata here
input_buffers['fill_seqlens'] = torch.zeros(max_batches, dtype=torch.int64, device=device)
+ input_buffers['cu_seqlens'] = torch.zeros(2, max_batches + 1, dtype=torch.int32, device=device)
+ input_buffers['cu_seqlens_q'] = input_buffers['cu_seqlens'][0]
+ input_buffers['cu_seqlens_k'] = input_buffers['cu_seqlens'][1]
+
return input_buffers
def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor,
@@ -108,6 +123,8 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens))
input_buffers['qkv_lens'].zero_()
input_buffers['qkv_lens'][:, :batch_size] = qkv
+ input_buffers['cu_seqlens_q'][1:batch_size + 1] = input_buffers['q_seqlens'][:batch_size].cumsum(0)
+ input_buffers['cu_seqlens_k'][1:batch_size + 1] = input_buffers['kv_seqlens'][:batch_size].cumsum(0)
if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
if 'inputs_embeds' not in input_buffers:
@@ -121,10 +138,20 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
attn_metadata.q_start_loc = input_buffers['q_start_loc']
attn_metadata.q_seqlens = input_buffers['q_seqlens']
attn_metadata.kv_seqlens = input_buffers['kv_seqlens']
+ attn_metadata.cu_seqlens_q = input_buffers['cu_seqlens_q']
+ attn_metadata.cu_seqlens_k = input_buffers['cu_seqlens_k']
if getattr(self.config, 'use_flash_mla', False) is True:
import flash_mla
- tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32),
- self.config.num_attention_heads, 1)
+ num_attention_heads = self.config.num_attention_heads
+ index_topk = graph_meta.mla_index_topk
+ num_heads_q = None if index_topk is None else num_attention_heads
+ tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
+ attn_metadata.kv_seqlens.to(torch.int32),
+ num_attention_heads,
+ num_heads_k=1,
+ num_heads_q=num_heads_q,
+ is_fp8_kvcache=graph_meta.use_mla_fp8_cache,
+ topk=index_topk)
# here we use copy_ instead of = to avoid using new allocated mem for cuda graph
input_buffers['tile_scheduler_metadata'].copy_(tile_scheduler_metadata)
input_buffers['num_splits'][:new_batch_size + 1].copy_(num_splits[:new_batch_size + 1])
diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py
index 7a1654db4b..e2b4c5c191 100644
--- a/lmdeploy/pytorch/nn/attention.py
+++ b/lmdeploy/pytorch/nn/attention.py
@@ -77,9 +77,15 @@ def forward(
k_scales_zeros: torch.Tensor = None,
v_scales_zeros: torch.Tensor = None,
s_aux: torch.Tensor = None,
+ nsa_indices: torch.Tensor = None,
inplace: bool = True,
) -> torch.Tensor:
"""forward."""
+ kwargs = dict()
+ if nsa_indices is not None:
+ kwargs['nsa_indices'] = nsa_indices
+ if s_aux is not None:
+ kwargs['learnable_sink'] = s_aux
return self.impl.forward(
query,
key,
@@ -89,8 +95,8 @@ def forward(
attn_metadata=attn_metadata,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
- learnable_sink=s_aux,
inplace=inplace,
+ **kwargs,
)
@staticmethod
diff --git a/lmdeploy/pytorch/nn/nsa.py b/lmdeploy/pytorch/nn/nsa.py
new file mode 100644
index 0000000000..d3944bf003
--- /dev/null
+++ b/lmdeploy/pytorch/nn/nsa.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import Tensor, nn
+
+from lmdeploy.pytorch.backends import OpType, get_backend
+from lmdeploy.pytorch.backends.attention import AttentionMetadata
+from lmdeploy.pytorch.backends.nsa import NSAIndexMeta
+from lmdeploy.pytorch.model_inputs import get_step_ctx_manager
+
+
+class IndexerTopKFP8(nn.Module):
+
+ def __init__(self, topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1):
+ super().__init__()
+ backend = get_backend()
+ index_builder = backend.get_layer_impl_builder(OpType.NSAIndexFP8)
+ self.index_impl = index_builder.build(topk, softmax_scale, block_size, fill)
+
+ def forward(
+ self,
+ q: Tensor,
+ k: Tensor,
+ weights: Tensor,
+ k_cache: Tensor,
+ k_s_cache: Tensor,
+ attn_metadata: AttentionMetadata = None,
+ ):
+ """forward."""
+ step_ctx = get_step_ctx_manager().current_context()
+ cache_config = step_ctx.cache_config
+ max_tokens = cache_config.block_size * cache_config.num_gpu_blocks
+ is_decoding = attn_metadata.is_decoding
+ if q.size(0) == attn_metadata.kv_seqlens.size(0):
+ is_decoding = True
+ max_q_seqlen = 1 if is_decoding else q.size(0)
+        # set max_kv_seqlen to the maximum allocatable cache length so the shape stays fixed for cudagraph
+ max_kv_seqlen = max_tokens if is_decoding else attn_metadata.kv_flatten_size
+ meta = NSAIndexMeta(cu_seqlen_q=attn_metadata.cu_seqlens_q,
+ q_seqlens=attn_metadata.q_seqlens,
+ k_seqlens=attn_metadata.kv_seqlens,
+ block_offset=attn_metadata.block_offsets,
+ max_q_seqlen=max_q_seqlen,
+ max_kv_seqlen=max_kv_seqlen)
+ ret = self.index_impl.forward(q, k, weights, k_cache, k_s_cache, meta=meta)
+ return ret
diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py
index 795220bd02..77cda0da20 100644
--- a/lmdeploy/pytorch/strategies/base/model_inputs.py
+++ b/lmdeploy/pytorch/strategies/base/model_inputs.py
@@ -38,6 +38,7 @@ def make_dummy_inputs(batch_size: int,
max_kv_seqlen=max_kv_seqlen,
sum_kv_seqlen=num_tokens,
local_adapter_ids=local_adapter_ids,
+ is_dummy=True,
)
diff --git a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py
index 1e734c4073..369862e60e 100644
--- a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py
+++ b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py
@@ -42,3 +42,43 @@ def fp8_gemm_nt(a, b, d, c, recipe=None, compiled_dim='nk', disable_ue8m0_cast=F
N, _ = b[0].shape
with _log_jit_build(M, N, K):
gemm_fp8_fp8_bf16_nt(a, b, d)
+
+
+try:
+ from deep_gemm import m_grouped_fp8_gemm_nt_contiguous
+except Exception:
+ from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
+
+ def m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, recipe=None, compiled_dims='nk', disable_ue8m0_cast=False):
+ assert recipe is None
+ assert compiled_dims == 'nk'
+ assert disable_ue8m0_cast is False
+ return m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(a, b, d, m_indices)
+
+
+try:
+ from deep_gemm import m_grouped_fp8_gemm_nt_masked
+except Exception:
+ from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+
+ def m_grouped_fp8_gemm_nt_masked(a,
+ b,
+ d,
+ masked_m,
+ expected_m,
+ recipe=None,
+ compiled_dims='nk',
+ disable_ue8m0_cast=False):
+ assert recipe is None
+ assert compiled_dims == 'nk'
+ assert disable_ue8m0_cast is False
+ return m_grouped_gemm_fp8_fp8_bf16_nt_masked(a, b, d, masked_m, expected_m)
+
+
+try:
+ from deep_gemm import get_mn_major_tma_aligned_tensor
+except Exception:
+ from deep_gemm import get_col_major_tma_aligned_tensor
+
+ def get_mn_major_tma_aligned_tensor(x):
+ return get_col_major_tma_aligned_tensor(x)
diff --git a/lmdeploy/pytorch/transformers/__init__.py b/lmdeploy/pytorch/transformers/__init__.py
new file mode 100644
index 0000000000..bfafdb1899
--- /dev/null
+++ b/lmdeploy/pytorch/transformers/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import lru_cache
+
+from transformers import AutoConfig
+
+from lmdeploy.utils import get_logger
+
+logger = get_logger('lmdeploy')
+
+
+@lru_cache()
+def register_config(model_type: str):
+ if model_type == 'deepseek_v32':
+ from lmdeploy.pytorch.transformers.configuration_deepseek_v32 import DeepseekV32Config
+ AutoConfig.register(DeepseekV32Config.model_type, DeepseekV32Config)
+ else:
+        logger.debug(f'Cannot register config for model_type: {model_type}')
+
+
+def config_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
+ try:
+ return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ except ValueError as e:
+ logger.debug(f'AutoConfig.from_pretrained failed: {e}, try register config manually.')
+        # some models (e.g. dsv32) do not provide an auto map for their config
+ from transformers import PretrainedConfig
+ trust_remote_code = kwargs.pop('trust_remote_code', None)
+ config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
+ model_type = config_dict.get('model_type', None)
+ if trust_remote_code is not None:
+ kwargs['trust_remote_code'] = trust_remote_code
+ register_config(model_type)
+
+ return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
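+
+
+# Usage sketch (illustrative path):
+#
+#   from lmdeploy.pytorch.transformers import config_from_pretrained
+#   config = config_from_pretrained('/path/to/deepseek-v3.2', trust_remote_code=True)
+#
+# If the checkpoint's config.json declares `model_type: deepseek_v32` without an
+# auto_map entry, the first attempt raises ValueError and the fallback registers
+# `DeepseekV32Config` before retrying.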
diff --git a/lmdeploy/pytorch/transformers/configuration_deepseek_v32.py b/lmdeploy/pytorch/transformers/configuration_deepseek_v32.py
new file mode 100644
index 0000000000..59f91fb3ef
--- /dev/null
+++ b/lmdeploy/pytorch/transformers/configuration_deepseek_v32.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
+
+
+class DeepseekV32Config(DeepseekV3Config):
+ model_type = 'deepseek_v32'
+
+ def __init__(self, index_head_dim=128, index_n_heads=64, index_topk=2048, **kwargs):
+ super().__init__(**kwargs)
+ self.index_head_dim = index_head_dim
+ self.index_n_heads = index_n_heads
+ self.index_topk = index_topk
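+
+
+# Example with toy sizes (the three indexer fields are the only additions over
+# DeepseekV3Config; the remaining kwargs are standard DeepseekV3Config fields):
+#
+#   cfg = DeepseekV32Config(index_head_dim=128, index_n_heads=64, index_topk=2048,
+#                           hidden_size=1024, num_hidden_layers=4)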
diff --git a/tests/pytorch/kernel/test_bitonic_topk.py b/tests/pytorch/kernel/test_bitonic_topk.py
new file mode 100644
index 0000000000..dd58f73361
--- /dev/null
+++ b/tests/pytorch/kernel/test_bitonic_topk.py
@@ -0,0 +1,67 @@
+import pytest
+import torch
+
+
+class TestBitonicTopk:
+
+ @pytest.fixture
+ def device(self):
+ yield 'cuda'
+
+ @pytest.fixture
+ def k(self):
+ yield 2048
+
+ @pytest.fixture
+ def q_seqlens(self, device):
+ ret = [4, 16, 1, 32]
+ ret = torch.tensor(ret, dtype=torch.int32, device=device)
+ yield ret
+
+ @pytest.fixture
+ def kv_seqlens(self, device):
+ ret = [1024, 2048, 4096, 4096 + 133]
+ ret = torch.tensor(ret, dtype=torch.int32, device=device)
+ yield ret
+
+ @pytest.fixture
+ def batch_size(self, kv_seqlens):
+ return kv_seqlens.numel()
+
+ @pytest.fixture
+ def max_kv_len(self, kv_seqlens):
+ return kv_seqlens.max().item()
+
+ @pytest.fixture
+ def scores(self, q_seqlens, max_kv_len, device):
+ num_tokens = q_seqlens.sum().item()
+ yield torch.randn((num_tokens, max_kv_len), device=device)
+
+ @pytest.fixture
+ def gt(self, scores, q_seqlens, kv_seqlens, k):
+ batch_size = kv_seqlens.numel()
+ num_tokens, _ = scores.shape
+ topk_indices = torch.empty((num_tokens, k), dtype=torch.int32, device=scores.device)
+ topk_indices.fill_(-1)
+
+ start = 0
+ for i in range(batch_size):
+ q_seqlen = q_seqlens[i].item()
+ seqlen = kv_seqlens[i].item()
+ tmp_k = min(seqlen, k)
+ end = start + q_seqlen
+ _, topk_indices[start:end, :seqlen] = torch.topk(scores[start:end, :seqlen],
+ tmp_k,
+ largest=True,
+ sorted=True)
+ start = end
+ return topk_indices
+
+ def test_bitonic_topk(self, scores, q_seqlens, kv_seqlens, k, gt):
+ from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk
+ out = bitonic_topk(scores, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, k=k, fill=-1)
+ gt[gt < 0] = 0
+ out[out < 0] = 0
+ gt_score = torch.gather(scores, 1, gt.to(torch.int64))
+ out_score = torch.gather(scores, 1, out.to(torch.int64))
+ torch.testing.assert_close(gt_score, out_score)
diff --git a/tests/pytorch/kernel/test_ds_index.py b/tests/pytorch/kernel/test_ds_index.py
new file mode 100644
index 0000000000..dc5ef009f5
--- /dev/null
+++ b/tests/pytorch/kernel/test_ds_index.py
@@ -0,0 +1,146 @@
+import pytest
+import torch
+
+
+def _make_A(M, K, group_size, out_dtype, device):
+    quant_A = torch.rand(M, K // group_size, group_size, dtype=torch.float32, device=device)
+    # map to [-1, 1)
+    quant_A = quant_A * 2 - 1
+ # scaling abs max to fmax
+ finfo = torch.finfo(out_dtype)
+ fmax = finfo.max
+ scaling = fmax / quant_A.abs().amax(-1, keepdim=True)
+ quant_A *= scaling
+ quant_A = quant_A.to(out_dtype).to(torch.float32)
+
+ # create scale and A
+    scale = torch.rand(M, K // group_size, dtype=torch.float32, device=device)
+ scale /= fmax
+ A = quant_A * scale[..., None]
+
+ A = A.reshape(M, K)
+ quant_A = quant_A.reshape(M, K).to(out_dtype)
+ scale = scale.T.contiguous().T
+ return A, quant_A, scale
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
+class TestDSIndex:
+
+ @pytest.fixture
+ def num_heads(self):
+ yield 64
+
+ @pytest.fixture
+ def head_dim(self):
+ yield 128
+
+ @pytest.fixture
+ def block_size(self):
+ yield 64
+
+ @pytest.fixture
+ def device(self):
+ yield 'cuda'
+
+ @pytest.fixture
+ def q_seqlens(self, request):
+ yield request.param
+
+ @pytest.fixture
+ def kv_seqlens(self, request):
+ yield request.param
+
+ @pytest.fixture
+ def k_seqlens(self, kv_seqlens, device):
+ yield torch.tensor(kv_seqlens, dtype=torch.int32, device=device)
+
+ @pytest.fixture
+ def cu_seqlen_q(self, q_seqlens, device):
+ yield torch.tensor([0] + list(q_seqlens), dtype=torch.int32, device=device).cumsum(0)
+
+ @pytest.fixture
+ def cu_seqlen_kv(self, kv_seqlens, device):
+ yield torch.tensor([0] + list(kv_seqlens), dtype=torch.int32, device=device).cumsum(0)
+
+ @pytest.fixture
+ def query(self, q_seqlens, num_heads, head_dim, device):
+ total_len = sum(q_seqlens)
+ fp_q, q, q_s = _make_A(total_len * num_heads, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device=device)
+ fp_q = fp_q.view(total_len, num_heads, head_dim)
+ q = q.view(total_len, num_heads, head_dim)
+ q_s = q_s.view(total_len, num_heads)
+ yield fp_q, q, q_s
+
+ @pytest.fixture
+ def q(self, query):
+ yield query[1]
+
+ @pytest.fixture
+ def q_s(self, query):
+ yield query[2]
+
+ @pytest.fixture
+ def key(self, kv_seqlens, head_dim):
+ total_len = sum(kv_seqlens)
+ fp_k, k, k_s = _make_A(total_len, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device='cuda')
+ fp_k = fp_k.view(total_len, head_dim)
+ k = k.view(total_len, head_dim)
+ k_s = k_s.view(total_len)
+ yield fp_k, k, k_s
+
+ @pytest.fixture
+ def k(self, key):
+ yield key[1]
+
+ @pytest.fixture
+ def k_s(self, key):
+ yield key[2]
+
+ @pytest.fixture
+ def cache_key(self, k, k_s, kv_seqlens, block_size, head_dim):
+ batch_size = len(kv_seqlens)
+ max_num_blocks = (max(kv_seqlens) + block_size - 1) // block_size
+
+ # get block offsets
+ batch_ids = torch.arange(batch_size, device='cuda') * max_num_blocks
+ block_ids = torch.arange(max_num_blocks, device='cuda')
+ block_offsets = (batch_ids[:, None] + block_ids[None, :])
+
+ k_cache = torch.zeros((max_num_blocks * batch_size * block_size, head_dim),
+ dtype=torch.float8_e4m3fn,
+ device='cuda')
+ k_s_cache = torch.zeros((max_num_blocks * batch_size * block_size), dtype=torch.float32, device='cuda')
+
+ k = k.split(kv_seqlens, dim=0)
+ k_s = k_s.split(kv_seqlens, dim=0)
+ for i in range(batch_size):
+ size = k[i].size(0)
+ start = i * max_num_blocks * block_size
+ end = start + size
+ k_cache[start:end] = k[i]
+ k_s_cache[start:end] = k_s[i]
+
+ k_cache = k_cache.view(batch_size * max_num_blocks, block_size, head_dim)
+ k_s_cache = k_s_cache.view(batch_size * max_num_blocks, block_size)
+
+ yield k_cache, k_s_cache, block_offsets
+
+ @pytest.fixture
+ def k_cache(self, cache_key):
+ yield cache_key[0]
+
+ @pytest.fixture
+ def k_s_cache(self, cache_key):
+ yield cache_key[1]
+
+ @pytest.fixture
+ def block_offset(self, cache_key):
+ yield cache_key[2]
+
+ @pytest.mark.parametrize('q_seqlens', [(1, 1, 1, 1), (1024, 2048, 1024, 1)], indirect=True)
+ @pytest.mark.parametrize('kv_seqlens', [(2048, 4096, 1024, 128)], indirect=True)
+ def test_fp8_index(self, q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset):
+        # computing the gt requires tilelang, so this test only checks that the kernel runs
+ from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index
+ fp8_index(q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset)
diff --git a/tests/pytorch/kernel/test_fill_kv_cache.py b/tests/pytorch/kernel/test_fill_kv_cache.py
index 51a9bf32ac..1a5bc565a0 100644
--- a/tests/pytorch/kernel/test_fill_kv_cache.py
+++ b/tests/pytorch/kernel/test_fill_kv_cache.py
@@ -265,3 +265,128 @@ def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, k_scales_ze
torch.testing.assert_close(v_scales_zeros, gt[3])
torch.testing.assert_close(k_caches, gt[0])
torch.testing.assert_close(v_caches, gt[1])
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0')
+class TestFillKVCacheBlockedFP8(TestFillKVCache):
+
+ @pytest.fixture
+ def quant_dtype(self):
+ yield torch.float8_e4m3fn
+
+ @pytest.fixture
+ def num_heads(self):
+ yield 4
+
+ @pytest.fixture
+ def head_dim(self):
+ yield 128
+
+ @pytest.fixture
+ def block_size(self):
+ yield 64
+
+ @pytest.fixture
+ def group_size(self):
+ yield 128
+
+ @pytest.fixture
+ def cu_seqlen_q(self, q_start_loc, q_seq_length):
+ batch_size = q_start_loc.size(0)
+ cu_seqlen = torch.zeros(batch_size + 1, dtype=torch.int32).cuda()
+ cu_seqlen[1:] = q_start_loc + q_seq_length
+ return cu_seqlen
+
+ @pytest.fixture
+ def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, quant_dtype):
+ shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim)
+ yield torch.full(shape, 0, dtype=quant_dtype).cuda()
+
+ @pytest.fixture
+ def v_caches(self, k_caches):
+ yield torch.zeros_like(k_caches)
+
+ @pytest.fixture
+ def ks_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, group_size):
+ shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim // group_size)
+ yield torch.full(shape, 0.0).cuda()
+
+ @pytest.fixture
+ def vs_caches(self, ks_caches):
+ yield torch.ones_like(ks_caches)
+
+ @pytest.fixture
+ def gt(self, k_states, v_states, group_size, quant_dtype):
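+ # ground truth: quantize the raw k/v states per group_size group with quant_fp8,
+ # then restore the (token, head, ...) layout for comparison with the cached values.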
+ from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+ batch_size = k_states.size(0)
+ num_heads = k_states.size(1)
+ head_dim = k_states.size(2)
+
+ k_states = k_states.flatten(0, -2)
+ v_states = v_states.flatten(0, -2)
+ quant_k, quant_ks = quant_fp8(k_states, group_size=group_size, dtype=quant_dtype)
+ quant_v, quant_vs = quant_fp8(v_states, group_size=group_size, dtype=quant_dtype)
+
+ quant_k = quant_k.view(batch_size, num_heads, head_dim)
+ quant_ks = quant_ks.view(batch_size, num_heads, head_dim // group_size)
+ quant_v = quant_v.view(batch_size, num_heads, head_dim)
+ quant_vs = quant_vs.view(batch_size, num_heads, head_dim // group_size)
+
+ yield quant_k, quant_ks, quant_v, quant_vs
+
+ def uncache(self, k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q, kv_seqlens, block_offsets):
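+ # Read back, for each sequence, the freshly written tail (kv_len - q_len .. kv_len)
+ # from the paged caches and concatenate across the batch for comparison with gt.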
+ batch_size = block_offsets.size(0)
+ out_k = []
+ out_ks = []
+ out_v = []
+ out_vs = []
+ q_seqlens = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
+ for bidx in range(batch_size):
+ seqlen = q_seqlens[bidx].item()
+ kv_len = kv_seqlens[bidx].item()
+ start = kv_len - seqlen
+ end = kv_len
+ k = k_caches[block_offsets[bidx]].reshape(-1, k_caches.size(-2), k_caches.size(-1))
+ ks = ks_caches[block_offsets[bidx]].reshape(-1, ks_caches.size(-2), ks_caches.size(-1))
+ v = v_caches[block_offsets[bidx]].reshape(-1, v_caches.size(-2), v_caches.size(-1))
+ vs = vs_caches[block_offsets[bidx]].reshape(-1, vs_caches.size(-2), vs_caches.size(-1))
+ out_k.append(k[start:end])
+ out_ks.append(ks[start:end])
+ out_v.append(v[start:end])
+ out_vs.append(vs[start:end])
+ out_k = torch.cat(out_k, dim=0)
+ out_ks = torch.cat(out_ks, dim=0)
+ out_v = torch.cat(out_v, dim=0)
+ out_vs = torch.cat(out_vs, dim=0)
+ return out_k, out_ks, out_v, out_vs
+
+ @pytest.mark.parametrize(['seq_lens', 'history_lens'], [
+ ((1, 1, 1, 1), (1, 128, 256, 200)),
+ ((1, 64, 128, 50), (1, 128, 256, 200)),
+ ],
+ indirect=True)
+ def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, ks_caches, vs_caches, block_offsets,
+ cu_seqlen_q, kv_seq_length, max_q_seq_length, gt, group_size):
+ from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8
+ fill_kv_cache_blocked_fp8(k_states,
+ v_states,
+ k_caches,
+ v_caches,
+ ks_caches,
+ vs_caches,
+ cu_seqlen_q,
+ kv_seq_length,
+ max_q_seq_length,
+ block_offsets=block_offsets,
+ group_size=group_size)
+
+ gt_k, gt_ks, gt_v, gt_vs = gt
+
+ # uncache
+ out_k, out_ks, out_v, out_vs = self.uncache(k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q,
+ kv_seq_length, block_offsets)
+
+ torch.testing.assert_close(out_k.float(), gt_k.float())
+ torch.testing.assert_close(out_ks, gt_ks)
+ torch.testing.assert_close(out_v.float(), gt_v.float())
+ torch.testing.assert_close(out_vs, gt_vs)
diff --git a/tests/pytorch/kernel/test_flatten_kv_cache.py b/tests/pytorch/kernel/test_flatten_kv_cache.py
index 5e2fffb510..80cfc34c84 100644
--- a/tests/pytorch/kernel/test_flatten_kv_cache.py
+++ b/tests/pytorch/kernel/test_flatten_kv_cache.py
@@ -170,3 +170,75 @@ def atol(self):
@pytest.fixture
def rtol(self):
yield 1e-3
+
+
+@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='requires device with cc>=9.0')
+class TestFlattenKVCacheMLAFP8(TestFlattenKVCache):
+
+ @pytest.fixture
+ def out_dtype(self):
+ yield torch.bfloat16
+
+ @pytest.fixture
+ def num_heads(self):
+ yield 1
+
+ @pytest.fixture
+ def head_dim(self):
+ yield 576
+
+ @pytest.fixture
+ def block_size(self):
+ yield 64
+
+ @pytest.fixture
+ def k_cache_mla(self, k_caches):
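+ # pack the MLA cache per token along the last dim as
+ # [512 fp8 nope values | 4 fp32 group scales viewed as 16 fp8 bytes | 64 bf16 rope values viewed as 128 fp8 bytes].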
+ from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+ num_blocks, block_size, num_heads, _ = k_caches.shape
+ k_cache_pe = k_caches[:, :, :, 512:]
+ k_cache_nope = k_caches[:, :, :, :512].flatten(0, -2)
+ k_cache_nope, k_cache_scale = quant_fp8(k_cache_nope, group_size=128)
+ k_cache_nope = k_cache_nope.view(num_blocks, block_size, num_heads, -1)
+ k_cache_scale = k_cache_scale.reshape(num_blocks, block_size, num_heads, -1).to(torch.float32)
+ dtype = k_cache_nope.dtype
+ out = torch.cat([k_cache_nope, k_cache_scale.view(dtype), k_cache_pe.view(dtype)], dim=-1)
+ yield out
+
+ def _dequant(self, k_cache_mla):
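+ # reverse the packing above: split nope/scales/pe, rescale each 128-wide nope group
+ # by its fp32 scale, and rebuild the bf16 cache.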
+ k_cache_nope = k_cache_mla[..., :512].to(torch.float32)
+ k_cache_scale = k_cache_mla[..., 512:512 + 16].view(torch.float32)
+ k_cache_pe = k_cache_mla[..., 512 + 16:].view(torch.bfloat16)
+ k_cache_nope = k_cache_nope.unflatten(-1, (-1, 128))
+ k_cache_scale = k_cache_scale[..., None]
+ k_cache_nope *= k_cache_scale
+ k_cache_nope = k_cache_nope.flatten(-2, -1).to(k_cache_pe.dtype)
+ k_cache = torch.cat([k_cache_nope, k_cache_pe], dim=-1)
+ return k_cache
+
+ @pytest.fixture
+ def gt(self, k_cache_mla, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim):
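+ # ground truth: dequantize the packed cache and copy the valid tokens of every
+ # sequence, block by block, into a contiguous (heads, tokens, dim) tensor.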
+ k_caches = self._dequant(k_cache_mla)
+ k_states = k_caches.new_empty(num_heads, out_size, head_dim)
+ start_loc = 0
+ for kv_len, block_offs in zip(kv_lens, block_offsets):
+ remain_len = kv_len
+ for idx, _ in enumerate(range(0, kv_len, block_size)):
+ b_off = block_offs[idx]
+ block_len = min(block_size, remain_len)
+ end_loc = start_loc + block_len
+ k_block = k_caches[b_off, :block_len]
+ k_states[:, start_loc:end_loc] = k_block.transpose(0, 1)
+ start_loc = end_loc
+ remain_len -= block_len
+
+ yield k_states
+
+ def test_flatten_kv_cache(self, k_cache_mla, kv_seqlens, block_offsets, out_size, out_dtype, gt):
+ from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8
+
+ k_states = flatten_kv_cache_mla_fp8(k_cache_mla,
+ kv_seqlens,
+ block_offsets,
+ out_size=out_size,
+ out_dtype=out_dtype)
+ torch.testing.assert_close(k_states, gt)