diff --git a/README.md b/README.md index d204fddff6..498b814d19 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,8 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • DeepSeek-V2.5 (236B)
+ • DeepSeek-V3 (685B)
+ • DeepSeek-V3.2 (685B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • StarCoder2 (3B - 15B)
diff --git a/README_ja.md b/README_ja.md index 75d05390ad..b693918c3a 100644 --- a/README_ja.md +++ b/README_ja.md @@ -129,6 +129,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • DeepSeek-V2.5 (236B)
+ • DeepSeek-V3 (685B)
+ • DeepSeek-V3.2 (685B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • StarCoder2 (3B - 15B)
diff --git a/README_zh-CN.md b/README_zh-CN.md index f6f10a5b42..b1bc99f46e 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -143,6 +143,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
  • DeepSeek-V2.5 (236B)
+ • DeepSeek-V3 (685B)
+ • DeepSeek-V3.2 (685B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • StarCoder2 (3B - 15B)
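
The README hunks above add DeepSeek-V3 and DeepSeek-V3.2 to the supported-model lists. As orientation for reviewers, the sketch below shows how such a model would be served through LMDeploy's PyTorch engine; it is not part of this diff, and the Hugging Face repo id and `tp` value are illustrative assumptions.

```python
# Minimal serving sketch (not part of this PR). The model id and tp value are
# assumptions; adjust them to the actual checkpoint and GPU topology.
from lmdeploy import PytorchEngineConfig, pipeline

pipe = pipeline(
    'deepseek-ai/DeepSeek-V3.2-Exp',           # assumed HF repo id
    backend_config=PytorchEngineConfig(tp=8),  # multi-GPU tensor parallelism for a 685B MoE
)
print(pipe(['Hello from LMDeploy!']))
```
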
  • diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index aa28854d8a..966353b99e 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -90,6 +90,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V3 | 685B | LLM | Yes | No | No | No | No | +| DeepSeek-V3.2 | 685B | LLM | Yes | No | No | No | No | | DeepSeek-VL2 | 3B - 27B | MLLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | | MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 8e9e3fef20..27bd84b331 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -90,6 +90,8 @@ | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V3 | 685B | LLM | Yes | No | No | No | No | +| DeepSeek-V3.2 | 685B | LLM | Yes | No | No | No | No | | DeepSeek-VL2 | 3B - 27B | MLLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | | MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 0c842e8871..63941aebfb 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -15,6 +15,8 @@ class AttentionMetadata: q_seqlens: torch.Tensor = None kv_seqlens: torch.Tensor = None fill_seqlens: torch.Tensor = None + cu_seqlens_q: torch.Tensor = None + cu_seqlens_k: torch.Tensor = None quant_policy: Literal[0, 4, 8] = 0 @@ -70,6 +72,7 @@ def forward( k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, learnable_sink: torch.Tensor = None, + nsa_indices: torch.Tensor = None, inplace: bool = False, ) -> torch.Tensor: """forward.""" diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index fb3ff9ded7..937cdf01ee 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -31,6 +31,7 @@ class OpType(Enum): FusedMoEW8A8 = auto() LinearBlockedF8 = auto() FusedMoEBlockedF8 = auto() + NSAIndexFP8 = auto() class OpsBackend(ABC): diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index b241c384b2..69d8a47ea0 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -107,6 +107,7 @@ def forward( v_scales_zeros: torch.Tensor = None, learnable_sink: torch.Tensor = None, inplace: bool = True, + **kwargs, ) -> torch.Tensor: """forward.""" block_offsets = attn_metadata.block_offsets @@ -231,6 +232,69 @@ def use_fa3_warning(): return False +def _try_dynamic_compile(func, *args, **kwargs): + """Try compile.""" + try: + compiled_func = torch.compile(func, dynamic=True) + compiled_func(*args, **kwargs) + return compiled_func + except Exception: + return func + + +class NSAIndicesUpdater: + """NSA indices updater. + + Flash MLA sparse attention requires different indice format for prefill and decoding. This module is used to update + the indices to meet the requirements. 
+ """ + + def __init__(self): + self._update_decode_func = None + self._update_prefill_func = None + + def _update_decode_impl(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor, + block_size: int) -> torch.Tensor: + """Update for decode impl.""" + block_ids = nsa_indices // block_size + block_ids = block_ids.clamp_min(0) + block_ids = block_offsets.gather(1, block_ids) + block_remain = nsa_indices % block_size + ret = block_ids * block_size + block_remain + ret[nsa_indices < 0] = -1 + return ret[:, None] + + def update_decode(self, nsa_indices: torch.Tensor, block_offsets: torch.Tensor, block_size: int) -> torch.Tensor: + """Update for decode.""" + if self._update_decode_func is None: + self._update_decode_func = _try_dynamic_compile(self._update_decode_impl, nsa_indices, block_offsets, + block_size) + + return self._update_decode_func(nsa_indices, block_offsets, block_size) + + def _update_prefill_impl(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor): + """Update for prefill impl.""" + num_tokens = nsa_indices.size(0) + repeat_cu_seqlens_k = torch.repeat_interleave(cu_seqlens_k[:-1], q_seqlens, output_size=num_tokens) + neg_mask = nsa_indices < 0 + nsa_indices = nsa_indices + repeat_cu_seqlens_k[:, None] + nsa_indices[neg_mask] = -1 + return nsa_indices[:, None] + + def update_prefill(self, nsa_indices: torch.Tensor, q_seqlens: torch.Tensor, cu_seqlens_k: torch.Tensor): + """Update for prefill.""" + if self._update_prefill_func is None: + self._update_prefill_func = _try_dynamic_compile(self._update_prefill_impl, nsa_indices, q_seqlens, + cu_seqlens_k) + + return self._update_prefill_func(nsa_indices, q_seqlens, cu_seqlens_k) + + @staticmethod + @functools.lru_cache(maxsize=None) + def build(): + return NSAIndicesUpdater() + + class FlashMLAImpl(TritonAttentionImpl): def __init__( @@ -262,47 +326,264 @@ def __init__( **kwargs, ) - from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd - self.flash_mla_fwd = flash_mla_fwd + import flash_mla + + from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8 + from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8 + self.flash_mla_with_kvcache = flash_mla.flash_mla_with_kvcache + self.flash_mla_sparse_fwd = None + self.fill_kv_cache_blocked_fp8 = fill_kv_cache_blocked_fp8 + self.flatten_kv_cache_mla_fp8 = flatten_kv_cache_mla_fp8 assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1' - use_fa3_warning() - def forward( + self.nsa_updater = NSAIndicesUpdater.build() + + def _get_flash_mla_sparse_fwd(self): + if self.flash_mla_sparse_fwd is not None: + return self.flash_mla_sparse_fwd + + try: + import flash_mla + self.flash_mla_sparse_fwd = flash_mla.flash_mla_sparse_fwd + return self.flash_mla_sparse_fwd + except Exception: + logger.exception('Can not import flash_mla_sparse_fwd from flash_mla.') + + def flash_mla_decoding( self, query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, k_cache: torch.Tensor, - v_cache: torch.Tensor, + nsa_indices: torch.Tensor, attn_metadata: TritonAttentionMetadata, - k_scales_zeros: torch.Tensor = None, - v_scales_zeros: torch.Tensor = None, - learnable_sink: torch.Tensor = None, - inplace: bool = True, - ) -> torch.Tensor: - """forward.""" + ): + """Flash mla decoding.""" + causal = self.causal + kv_seqlens = attn_metadata.kv_seqlens + block_offsets = attn_metadata.block_offsets + is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn + + query = query.unsqueeze(1) + if kv_seqlens.dtype == 
torch.int64: + kv_seqlens = kv_seqlens.to(torch.int32) + + # update nsa indice according to flash-mla requirement + if nsa_indices is not None: + block_size = k_cache.size(1) + nsa_indices = self.nsa_updater.update_decode(nsa_indices, block_offsets, block_size) + causal = False + + attn_output, _ = self.flash_mla_with_kvcache(query, + k_cache=k_cache, + block_table=block_offsets, + cache_seqlens=kv_seqlens, + head_dim_v=self.v_head_size, + softmax_scale=self.scale, + tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata, + num_splits=attn_metadata.num_splits, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=nsa_indices) + attn_output = attn_output.squeeze(1) + return attn_output + + def flash_mla_prefill(self, query: torch.Tensor, flatten_k: torch.Tensor, nsa_indices: torch.Tensor, + attn_metadata: TritonAttentionMetadata) -> torch.Tensor: + """Flash mla prefill, only used in sparse attention.""" + q_seqlens = attn_metadata.q_seqlens + flash_mla_sparse_fwd = self._get_flash_mla_sparse_fwd() + + num_q_heads = query.size(1) + # flash_mla_sparse_fwd requires query heads to be multiple of 64 + if num_q_heads % 64 != 0: + query = torch.nn.functional.pad(query, (0, 0, 0, 64 - num_q_heads % 64)) + + nsa_indices = self.nsa_updater.update_prefill(nsa_indices, q_seqlens, attn_metadata.cu_seqlens_k) + output = flash_mla_sparse_fwd( + query, + flatten_k, + nsa_indices, + sm_scale=self.scale, + ) + attn_output = output[0] + attn_output = attn_output[:, :num_q_heads] + return attn_output + def flash_attn_triton( + self, + query: torch.Tensor, + flatten_k: torch.Tensor, + flatten_v: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + ): + """Triton flash attention, used if flash-attn is not available.""" + q_start_loc = attn_metadata.q_start_loc + q_seqlens = attn_metadata.q_seqlens + kv_start_loc = attn_metadata.kv_start_loc + kv_seqlens = attn_metadata.kv_seqlens + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + + q_shape = query.shape + o_shape = q_shape[:-1] + (self.v_head_size, ) + attn_output = query.new_empty(o_shape) + self.flash_attention_fwd( + query, + flatten_k, + flatten_v, + attn_output, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_seqlen=max_q_seqlen, + window_size=self.sliding_window, + sm_scale=self.scale, + logit_softcapping=self.logit_softcapping, + causal=self.causal, + ) + + return attn_output + + def flash_attn_fa3( + self, + query: torch.Tensor, + flatten_k: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + ): + """Flash attention 3, used if flash-attn 3 is available.""" + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + kv_flatten_size = attn_metadata.kv_flatten_size + causal = self.causal + q_rope = query[:, :, self.v_head_size:] + q_nope = query[:, :, :self.v_head_size] + k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:] + c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size] + from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func + attn_output = flash_attn_varlen_func( + q=q_rope, + k=k_rope, + v=c_kv, + qv=q_nope, + cu_seqlens_q=attn_metadata.cu_seqlens_q, + cu_seqlens_k=attn_metadata.cu_seqlens_k, + max_seqlen_q=max_q_seqlen, + max_seqlen_k=kv_flatten_size, + softmax_scale=self.scale, + causal=causal, + window_size=(-1, -1) if self.sliding_window is None else self.sliding_window, + softcap=-1.0 if self.logit_softcapping is None else 
self.logit_softcapping, + ) + return attn_output + + def run_flatten_kv_cache(self, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + out_dtype: torch.dtype, + is_nsa: bool, + k_scales_zeros: torch.Tensor = None, + v_scales_zeros: torch.Tensor = None): + """Flatten kv cache for prefill.""" + + kv_start_loc = attn_metadata.kv_start_loc + kv_seqlens = attn_metadata.kv_seqlens + block_offsets = attn_metadata.block_offsets + kv_flatten_size = attn_metadata.kv_flatten_size + quant_policy = attn_metadata.quant_policy + is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn + BLOCK_BS = k_cache.size(1) + + # pad one more block to avoid invalid kv visit + out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS) + flatten_kv_layout = 'shd' if use_fa3 or is_nsa else 'hsd' + if is_fp8_kvcache: + flatten_k = self.flatten_kv_cache_mla_fp8( + k_cache, + kv_seqlens, + block_offsets, + start_loc=kv_start_loc, + out_size=out_size, + out_dtype=out_dtype, + flatten_kv_layout=flatten_kv_layout, + ) + flatten_v = flatten_k[..., :512] + else: + flatten_k, flatten_v = self.flatten_kv_cache( + k_cache, + v_cache, + kv_seqlens, + block_offsets, + start_loc=kv_start_loc, + out_size=kv_flatten_size if use_fa3 else out_size, + out_dtype=out_dtype, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + quant_policy=quant_policy, + flatten_kv_layout=flatten_kv_layout, + ) + + return flatten_k, flatten_v + + def run_fill_kv_cache(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + k_scales_zeros: torch.Tensor = None, + v_scales_zeros: torch.Tensor = None): + """Fill kv cache.""" block_offsets = attn_metadata.block_offsets q_start_loc = attn_metadata.q_start_loc fill_q_start_loc = q_start_loc q_seqlens = attn_metadata.q_seqlens fill_seqlens = q_seqlens - kv_start_loc = attn_metadata.kv_start_loc kv_seqlens = attn_metadata.kv_seqlens - kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy + + # max_q_seqlen if attn_metadata.is_decoding: max_q_seqlen = 1 else: max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + + # fill_max_q_seqlen fill_max_q_seqlen = max_q_seqlen if attn_metadata.fill_seqlens is not None: fill_seqlens = attn_metadata.fill_seqlens fill_max_q_seqlen = key.numel() // (key.size(-1) * key.size(-2)) fill_q_start_loc = fill_seqlens.cumsum(0) - fill_seqlens - # fill kv cache - if key is not None and value is not None: + is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn + if is_fp8_kvcache: + k_cache_scale = k_cache[..., 512:512 + 16].view(torch.float32) + k_cache_nope = k_cache[..., :512] + k_cache_pe = k_cache[..., 512 + 16:].view(key.dtype) + self.fill_kv_cache_blocked_fp8( + key[..., :512], + None, + k_cache_nope, + None, + k_cache_scale, + None, + cu_seqlen_q=attn_metadata.cu_seqlens_q, + kv_seqlens=attn_metadata.kv_seqlens, + max_q_seqlen=max_q_seqlen, + block_offsets=block_offsets, + group_size=128, + ) + self.fill_kv_cache( + key[..., 512:], + None, + k_cache_pe, + None, + fill_q_start_loc, + fill_seqlens, + kv_seq_length=kv_seqlens, + max_q_seq_length=fill_max_q_seqlen, + block_offsets=block_offsets, + ) + else: self.fill_kv_cache( key, value, @@ -318,77 +599,57 @@ def forward( quant_policy=quant_policy, ) - q_shape = query.shape - o_shape = q_shape[:-1] + (self.v_head_size, ) + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + 
k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: TritonAttentionMetadata, + k_scales_zeros: torch.Tensor = None, + v_scales_zeros: torch.Tensor = None, + nsa_indices: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """forward.""" + + # check nsa + is_fp8_kvcache = k_cache.dtype == torch.float8_e4m3fn + is_nsa = nsa_indices is not None + if is_nsa: + assert is_fp8_kvcache + + # fill kv cache + self.run_fill_kv_cache( + query, + key, + value, + k_cache, + v_cache, + attn_metadata, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros, + ) is_decoding = attn_metadata.is_decoding if is_decoding: - query = query.unsqueeze(1) - if kv_seqlens.dtype == torch.int64: - kv_seqlens = kv_seqlens.to(torch.int32) - attn_output = self.flash_mla_fwd(query, - k_cache=k_cache, - block_table=block_offsets, - cache_seqlens=kv_seqlens, - head_dim_v=self.v_head_size, - softmax_scale=self.scale, - tile_scheduler_metadata=attn_metadata.tile_scheduler_metadata, - num_splits=attn_metadata.num_splits, - causal=True) + attn_output = self.flash_mla_decoding(query, k_cache, nsa_indices, attn_metadata) else: - BLOCK_BS = k_cache.size(1) - # pad one more block to avoid invalid kv visit - out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS) - flatten_k, flatten_v = self.flatten_kv_cache( - k_cache, - v_cache, - kv_seqlens, - block_offsets, - start_loc=kv_start_loc, - out_size=kv_flatten_size if use_fa3 else out_size, - out_dtype=query.dtype, - k_scales_zeros=k_scales_zeros, - v_scales_zeros=v_scales_zeros, - quant_policy=quant_policy, - flatten_kv_layout='shd' if use_fa3 else 'hsd', - ) - if use_fa3: - q_rope = query[:, :, self.v_head_size:] - q_nope = query[:, :, :self.v_head_size] - k_rope = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, self.v_head_size:] - c_kv = flatten_k.view(kv_flatten_size, self.num_kv_heads, -1)[:, :, :self.v_head_size] - from lmdeploy.pytorch.third_party.flash_attn_interface import flash_attn_varlen_func - attn_output = flash_attn_varlen_func( - q=q_rope, - k=k_rope, - v=c_kv, - qv=q_nope, - cu_seqlens_q=attn_metadata.cu_seqlens_q, - cu_seqlens_k=attn_metadata.cu_seqlens_k, - max_seqlen_q=max_q_seqlen, - max_seqlen_k=kv_flatten_size, - softmax_scale=self.scale, - causal=self.causal, - window_size=(-1, -1) if self.sliding_window is None else self.sliding_window, - softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping, - ) + flatten_k, flatten_v = self.run_flatten_kv_cache(k_cache, + v_cache, + attn_metadata, + out_dtype=query.dtype, + is_nsa=nsa_indices is not None, + k_scales_zeros=k_scales_zeros, + v_scales_zeros=v_scales_zeros) + if is_nsa: + attn_output = self.flash_mla_prefill(query, flatten_k, nsa_indices, attn_metadata) + elif use_fa3: + attn_output = self.flash_attn_fa3(query, flatten_k, attn_metadata) else: - attn_output = query.new_empty(o_shape) - self.flash_attention_fwd( - query, - flatten_k, - flatten_v, - attn_output, - q_start_loc=q_start_loc, - q_seqlens=q_seqlens, - kv_start_loc=kv_start_loc, - kv_seqlens=kv_seqlens, - max_seqlen=max_q_seqlen, - window_size=self.sliding_window, - sm_scale=self.scale, - logit_softcapping=self.logit_softcapping, - causal=self.causal, - ) + attn_output = self.flash_attn_triton(query, flatten_k, flatten_v, attn_metadata) + return attn_output diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index deb6c66bfd..947ffe7646 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ 
b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -82,6 +82,9 @@ def __init__( input_buffers=dict(), output_buffers=dict(), vocab_size=self.model_config.vocab_size, + use_mla_fp8_cache=getattr(self.model_config, 'use_mla_fp8_cache', False), + use_flash_mla=getattr(self.model_config, 'use_flash_mla', False), + mla_index_topk=getattr(self.model_config, 'mla_index_topk', None), ) self.device = device self.max_batches = max_batches diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index a1394f28bd..b1c232f0e4 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import contextlib from typing import Callable, List, Optional import torch @@ -575,6 +576,41 @@ def fusedmoe_forward(self, state, up_weight, up_scale, down_weight, down_scale): return hidden_states +@contextlib.contextmanager +def mock_deep_gemm(): + from dlblas.kernels.fused_moe_v3 import use_deep_gemm + if use_deep_gemm: + yield + return + + # patch deep_gemm + import deep_gemm + import dlblas + + from lmdeploy.pytorch.third_party import deep_gemm as patched_deep_gemm + func0_ = getattr(deep_gemm, 'get_col_major_tma_aligned_tensor', None) + func1_ = getattr(deep_gemm, 'm_grouped_gemm_fp8_fp8_bf16_nt_masked', None) + deep_gemm.get_col_major_tma_aligned_tensor = patched_deep_gemm.get_mn_major_tma_aligned_tensor + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = patched_deep_gemm.m_grouped_fp8_gemm_nt_masked + + # patch dlblas + dlblas.kernels.fused_moe_v3.use_deep_gemm = True + dlblas.kernels.fused_moe_v3.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous = \ + patched_deep_gemm.m_grouped_fp8_gemm_nt_contiguous + yield + + # unpatch dlblas + dlblas.kernels.fused_moe_v3.use_deep_gemm = False + + # unpatch deep_gemm + if func0_ is not None: + deep_gemm.get_col_major_tma_aligned_tensor = func0_ + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked = func1_ + else: + del deep_gemm.get_col_major_tma_aligned_tensor + del deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked + + class FusedDeepEpMoEBlockedF8Impl(TritonFusedMoEBlockedF8Impl): def __init__(self, @@ -646,16 +682,35 @@ def do_renormalize(self, topk_weights): def fusedmoe_build(self, low_latency_mode: bool = False): from dlblas.layers.moe.ep_moe import build_deepep_moe - return build_deepep_moe(low_latency_mode, - self.ep_size, - self.ep_group, - self.num_experts, - self.hidden_dim, - self.block_size, - self.top_k, - self.out_dtype, - layer_idx=self.layer_idx, - chunk_size=16 * 1024) + deepep_moe = build_deepep_moe(low_latency_mode, + self.ep_size, + self.ep_group, + self.num_experts, + self.hidden_dim, + self.block_size, + self.top_k, + self.out_dtype, + layer_idx=self.layer_idx, + chunk_size=16 * 1024) + + # patch forward + _origin_forward = deepep_moe.forward + _origin_fusedmoe_forward = deepep_moe.fusedmoe_forward + + def _patched_forward(*args, **kwargs): + with mock_deep_gemm(): + out = _origin_forward(*args, **kwargs) + return out + + def _patched_fusedmoe_forward(*args, **kwargs): + with mock_deep_gemm(): + out = _origin_fusedmoe_forward(*args, **kwargs) + return out + + deepep_moe.forward = _patched_forward + deepep_moe.fusedmoe_forward = _patched_fusedmoe_forward + + return deepep_moe class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): diff --git a/lmdeploy/pytorch/backends/cuda/nsa.py b/lmdeploy/pytorch/backends/cuda/nsa.py new file mode 100644 index 0000000000..ed0c471a39 --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/nsa.py 
@@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 +from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index +from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8 + +from ..nsa import BaseNSAIndexFP8, BaseNSAIndexFP8Builder, NSAIndexMeta + + +class TritonNSAIndexFP8(BaseNSAIndexFP8): + + def __init__(self, topk: int, softmax_scale: float, block_size: int, fill: int) -> None: + super().__init__() + self.topk = topk + self.softmax_scale = softmax_scale + self.block_size = block_size + self.fill = fill + + def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor, + meta: NSAIndexMeta) -> Tensor: + + assert q.dim() == 3 + assert k.dim() == 2 + cu_seqlen_q = meta.cu_seqlen_q + q_seqlens = meta.q_seqlens + k_seqlens = meta.k_seqlens + block_offset = meta.block_offset + max_q_seqlen = meta.max_q_seqlen + max_kv_seqlen = meta.max_kv_seqlen + + q_shape = q.shape + q = q.reshape(-1, q_shape[-1]) + q, q_s = quant_fp8(q, self.block_size, dtype=k_cache.dtype, trans_scale=True) + q = q.reshape(*q_shape) + q_s = q_s.reshape(weights.shape) + q_s = q_s * self.softmax_scale * weights + + fill_kv_cache_blocked_fp8(k[:, None], + None, + k_cache[..., None, :], + None, + k_s_cache[..., None, :], + None, + cu_seqlen_q=cu_seqlen_q, + kv_seqlens=k_seqlens, + max_q_seqlen=max_q_seqlen, + block_offsets=block_offset, + group_size=self.block_size) + + scores = fp8_index(q, + q_s, + k_cache, + k_s_cache[..., 0], + cu_seqlen_q, + k_seqlens, + block_offset, + max_q_seqlen=max_q_seqlen, + max_k_seqlen=max_kv_seqlen) + return bitonic_topk(scores, q_seqlens, k_seqlens, self.topk, fill=self.fill, descending=True) + + +class TritonNSAIndexFP8Builder(BaseNSAIndexFP8Builder): + + @staticmethod + def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8: + return TritonNSAIndexFP8(topk, softmax_scale=softmax_scale, block_size=block_size, fill=fill) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index d6b77de59e..94b6d507df 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -12,13 +12,6 @@ logger = get_logger('lmdeploy') -def _get_meta_flashmla(kv_seqlens, num_attention_heads): - """Get meta for flashmla.""" - import flash_mla - tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(kv_seqlens.to(torch.int32), num_attention_heads, 1) - return tile_scheduler_metadata, num_splits - - class CudaOpsBackend(DefaultOpsBackend): """Cuda layer backend.""" @@ -72,6 +65,9 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.LinearBlockedF8: from .blockedf8_modules import TritonLinearBlockedF8Builder return TritonLinearBlockedF8Builder + elif layer_type == OpType.NSAIndexFP8: + from .nsa import TritonNSAIndexFP8Builder + return TritonNSAIndexFP8Builder else: logger.debug(f'Op {layer_type} fallback to default implementation.') return super().get_layer_impl_builder(layer_type) @@ -111,10 +107,19 @@ def get_v_block_shape( ) @classmethod - def update_meta_flashmla(cls, attn_metadata, num_attention_heads): + def update_meta_flashmla(cls, attn_metadata, model_config: ModelConfig): """Update meta for flashmla.""" - tile_scheduler_metadata, num_splits = 
_get_meta_flashmla(attn_metadata.kv_seqlens.to(torch.int32), - num_attention_heads) + import flash_mla + num_attention_heads = model_config.num_attention_heads + is_fp8_kvcache = model_config.use_mla_fp8_cache + index_topk = model_config.mla_index_topk + num_heads_q = None if index_topk is None else num_attention_heads + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32), + num_attention_heads, + num_heads_k=1, + num_heads_q=num_heads_q, + is_fp8_kvcache=is_fp8_kvcache, + topk=index_topk) attn_metadata.tile_scheduler_metadata = tile_scheduler_metadata attn_metadata.num_splits = num_splits @@ -149,7 +154,8 @@ def update_step_context(cls, step_context): ) if getattr(step_context.model_config, 'use_flash_mla', False) is True: if step_context.is_decoding is True: - cls.update_meta_flashmla(attn_metadata, step_context.model_config.num_attention_heads) + model_config = step_context.model_config + cls.update_meta_flashmla(attn_metadata, model_config) cross_seqlens = step_context.cross_seqlens cross_kv_seqlens = step_context.cross_kv_seqlens diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 228e4de1c4..db92aa39f7 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -66,6 +66,7 @@ def forward( k_scales_zeros: Tensor = None, v_scales_zeros: Tensor = None, learnable_sink: Tensor = None, + nsa_indices: Tensor = None, inplace: bool = True, ) -> Tensor: """forward.""" diff --git a/lmdeploy/pytorch/backends/nsa.py b/lmdeploy/pytorch/backends/nsa.py new file mode 100644 index 0000000000..c20c80868c --- /dev/null +++ b/lmdeploy/pytorch/backends/nsa.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from torch import Tensor + + +@dataclass +class NSAIndexMeta: + """Meta info of NSAIndex layer.""" + cu_seqlen_q: Tensor + q_seqlens: Tensor + k_seqlens: Tensor + block_offset: Tensor + max_q_seqlen: int = None + max_kv_seqlen: int = None + + +class BaseNSAIndexFP8(ABC): + + @abstractmethod + def forward(self, q: Tensor, k: Tensor, weights: Tensor, k_cache: Tensor, k_s_cache: Tensor, + meta: NSAIndexMeta) -> Tensor: + """forward.""" + raise NotImplementedError('Not implemented.') + + +class BaseNSAIndexFP8Builder: + + @staticmethod + @abstractmethod + def build(topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1) -> BaseNSAIndexFP8: + """Build layer implementation.""" + raise NotImplementedError('Not implemented.') diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py index d1d096dbfe..0e5884b69b 100644 --- a/lmdeploy/pytorch/check_env/model.py +++ b/lmdeploy/pytorch/check_env/model.py @@ -19,8 +19,8 @@ def check_config(self, trans_version): model_path = self.model_path trust_remote_code = self.trust_remote_code try: - from transformers import AutoConfig - config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + from lmdeploy.pytorch.transformers import config_from_pretrained + config = config_from_pretrained(model_path, trust_remote_code=trust_remote_code) except Exception as e: message = (f'Load model config with transformers=={trans_version}' ' failed. 
' @@ -57,7 +57,13 @@ def check_dtype(self, config): if not is_bf16_supported(device_type): logger.warning('Device does not support bfloat16.') except Exception as e: - message = (f'Checking failed with error {e}', 'Please send issue to LMDeploy with error logs.') + message = (f'Checking failed with error {e}. Please send issue to LMDeploy with error logs.') + self.log_and_exit(e, 'Model', message=message) + + try: + model_config.check_env_func(device_type) + except Exception as e: + message = (f'Checking failed with error {e}.') self.log_and_exit(e, 'Model', message=message) def check(self): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index ac3459e045..403d07fb37 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import enum -from dataclasses import dataclass -from typing import Any, Dict, List, Literal +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple import torch @@ -183,6 +183,10 @@ def override_hf_config(hf_config: Any, hf_overrides: Dict[str, Any]): _override_hf_config(hf_config, k, v) +def _default_check_env(device: str): + pass + + @dataclass class ModelConfig: """Config of model.""" @@ -203,11 +207,24 @@ class ModelConfig: llm_config: Any = None cogvlm_style: bool = False custom_module_map: Dict[str, setattr] = None + + # flash mla use_flash_mla: bool = False + use_mla_fp8_cache: bool = False + mla_index_topk: Optional[int] = None + + # dllm model_paradigm: str = 'ar' dllm_mask_token: int = 0 dllm_block_length: int = None + # Added for deepseekv3.2 nsa index + # caches would be added after kv cache + cache_shapes: List[Tuple[List[int], torch.dtype]] = field(default_factory=list) + + # check env for model-device combination + check_env_func: Callable = _default_check_env + def get_head_size(self): """Get head size.""" return self.head_dim @@ -232,9 +249,9 @@ def from_pretrained(cls, """ from transformers import AutoConfig + from lmdeploy.pytorch.transformers import config_from_pretrained from lmdeploy.utils import get_logger - - hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) + hf_config = config_from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) if getattr(hf_config, 'model_type', None) in ['phi3']: # phi3 + trust_remote_code leads to error when tp. hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path) diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py index f83abe38f2..f0eeb344ec 100644 --- a/lmdeploy/pytorch/configurations/deepseek_v2.py +++ b/lmdeploy/pytorch/configurations/deepseek_v2.py @@ -1,16 +1,38 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import torch + from lmdeploy.pytorch.config import ModelConfig from .builder import AutoModelConfigBuilder from .utils import flash_mla_available +def _check_env_v32(device: str = 'cuda'): + """Environment check.""" + if device != 'cuda': + return + + # check cuda + try: + import fast_hadamard_transform # noqa: F401 + except ImportError: + raise ImportError('Deepseek V3.2 requires .') + + try: + import flash_mla # noqa: F401 + except ImportError: + raise ImportError('Deepseek V3.2 requires .') + + if not hasattr(flash_mla, 'flash_mla_sparse_fwd'): + raise RuntimeError('Deepseek V3.2 latest with support.') + + class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): """config.""" - return hf_config.model_type in ['deepseek_v3', 'deepseek_v2', 'kimi_k2'] + return hf_config.model_type in ['deepseek_v3', 'deepseek_v2', 'deepseek_v32', 'kimi_k2'] @classmethod def build(cls, hf_config, model_path: str = None, **kwargs): @@ -26,14 +48,24 @@ def build(cls, hf_config, model_path: str = None, **kwargs): num_key_value_heads = cls.update_num_kv_heads(hf_config, tp, num_key_value_heads) hf_config.use_flash_mla = flash_mla_available() - return ModelConfig(hidden_size=hf_config.hidden_size, - num_layers=hf_config.num_hidden_layers, - num_attention_heads=num_attention_heads, - num_key_value_heads=num_key_value_heads, - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - head_dim=head_dim, - k_head_dim=k_head_dim, - v_head_dim=v_head_dim, - vocab_size=hf_config.vocab_size, - use_flash_mla=hf_config.use_flash_mla) + config = ModelConfig(hidden_size=hf_config.hidden_size, + num_layers=hf_config.num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + head_dim=head_dim, + k_head_dim=k_head_dim, + v_head_dim=v_head_dim, + vocab_size=hf_config.vocab_size, + use_flash_mla=hf_config.use_flash_mla) + + if hf_config.model_type == 'deepseek_v32': + assert hf_config.use_flash_mla, 'DeepSeek-V3.2 requires flash_mla to be available.' + index_k_shape = ([hf_config.index_head_dim], torch.float8_e4m3fn) + index_k_scale_shape = ([1], torch.float32) + config.cache_shapes = [index_k_shape, index_k_scale_shape] + config.use_mla_fp8_cache = True + config.mla_index_topk = hf_config.index_topk + config.check_env_func = _check_env_v32 + return config diff --git a/lmdeploy/pytorch/engine/cache_engine.py b/lmdeploy/pytorch/engine/cache_engine.py index d8ec198349..bbe18cd7af 100644 --- a/lmdeploy/pytorch/engine/cache_engine.py +++ b/lmdeploy/pytorch/engine/cache_engine.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
# modify from: https://github.com/vllm-project/vllm import json -from typing import Dict, List, Literal, Optional, Tuple +import math +from dataclasses import dataclass +from typing import Dict, List, Literal, Optional, Sequence, Tuple import torch @@ -20,6 +22,35 @@ logger = get_logger('lmdeploy') +def round_up(x: int, alignment: int) -> int: + """Round up x to the nearest multiple of alignment.""" + return ((x + alignment - 1) // alignment) * alignment + + +@dataclass +class CacheDesc: + """Cache description.""" + shape: List[int] + dtype: torch.dtype + alignment: int = 256 + + def __post_init__(self): + self.numel = math.prod(self.shape) + self.size = self.numel * self.dtype.itemsize + self.aligned_size = round_up(self.size, self.alignment) + + +def _get_kv_cache_dtype(model_config: ModelConfig): + kv_cache_dtype = model_config.dtype + if model_config.use_mla_fp8_cache: + kv_cache_dtype = torch.float8_e4m3fn + return kv_cache_dtype + + +# 512*1 + 4*4 + 64*2 = 656 +MLA_FP8_HEAD_DIM = 656 + + class CacheEngine: """Host and Device memory maintainer. @@ -50,7 +81,11 @@ def __init__( self.block_size = cache_config.block_size self.num_layers = model_config.num_layers - self.kv_cache_dtype = model_config.dtype + self.kv_cache_dtype = _get_kv_cache_dtype(self.model_config) + + if self.model_config.use_mla_fp8_cache: + cache_config.quant_policy = 0 + if cache_config.quant_policy > 0: if self.cache_config.device_type in ['cuda']: self.kv_cache_dtype = torch.uint8 @@ -100,16 +135,21 @@ def _get_key_block_shape_impl(cls, block_size: int, head_size: int, world_size: int = 1, - quant_policy: Literal[0, 4, 8] = 0, - local: bool = True): + quant_policy: Literal[0, 4, 8] = 0): """Get single block shape.""" attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads - if local: - assert num_heads % world_size == 0, \ - f'num_heads: {num_heads}, world_size: {world_size}' - num_heads = num_heads // world_size + + # split heads by tp + assert num_heads % world_size == 0, \ + f'num_heads: {num_heads}, world_size: {world_size}' + num_heads = num_heads // world_size + + # patch for flash mla + if model_config.use_mla_fp8_cache: + return (block_size, num_heads, MLA_FP8_HEAD_DIM) + if quant_policy == 4: # pack head_dim to uint8 assert head_size % 2 == 0, \ f'head_size: {head_size}, quant_policy: {quant_policy}' @@ -122,16 +162,22 @@ def _get_value_block_shape_impl(cls, block_size: int, head_size: int, world_size: int = 1, - quant_policy: Literal[0, 4, 8] = 0, - local: bool = True): + quant_policy: Literal[0, 4, 8] = 0): """Get single block shape.""" attn_backend = get_backend() dtype = model_config.dtype num_heads = model_config.num_key_value_heads - if local: - assert num_heads % world_size == 0, \ - f'num_heads: {num_heads}, world_size: {world_size}' - num_heads = num_heads // world_size + + # split heads by tp + assert num_heads % world_size == 0, \ + f'num_heads: {num_heads}, world_size: {world_size}' + num_heads = num_heads // world_size + + # patch for flash mla + if model_config.use_mla_fp8_cache: + # flash mla shared key and value + return (block_size, num_heads, 0) + if quant_policy == 4: # pack head_dim to uint8 assert head_size % 2 == 0, \ f'head_size: {head_size}, quant_policy: {quant_policy}' @@ -139,86 +185,155 @@ def _get_value_block_shape_impl(cls, return attn_backend.get_v_block_shape(block_size, num_heads, head_size, dtype) - def get_key_block_shape(self, local: bool = False) -> Tuple[int, int, int]: - """Get shape of key block.""" - head_size = 
self.model_config.k_head_dim + @classmethod + def get_k_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc: + """Get key cache description.""" + head_size = model_config.k_head_dim if head_size is None: - head_size = self.model_config.head_dim - return self._get_key_block_shape_impl( - self.model_config, - block_size=self.block_size, + head_size = model_config.head_dim + shape = cls._get_key_block_shape_impl( + model_config, + block_size=cache_config.block_size, head_size=head_size, - world_size=self.world_size, - quant_policy=self.cache_config.quant_policy, - local=local, + world_size=world_size, + quant_policy=cache_config.quant_policy, ) + shape = list(shape) + dtype = _get_kv_cache_dtype(model_config) + if cache_config.quant_policy in (4, 8): + dtype = torch.uint8 + return CacheDesc(shape=shape, dtype=dtype) - def get_value_block_shape(self, local: bool = False) -> Tuple[int, int, int]: - """Get shape of value block.""" - head_size = self.model_config.v_head_dim + @classmethod + def get_v_cache_desc(cls, model_config: ModelConfig, cache_config: CacheConfig, world_size: int = 1) -> CacheDesc: + """Get value cache description.""" + head_size = model_config.v_head_dim if head_size is None: - head_size = self.model_config.head_dim - return self._get_value_block_shape_impl( - self.model_config, - block_size=self.block_size, + head_size = model_config.head_dim + shape = cls._get_value_block_shape_impl( + model_config, + block_size=cache_config.block_size, head_size=head_size, - world_size=self.world_size, - quant_policy=self.cache_config.quant_policy, - local=local, + world_size=world_size, + quant_policy=cache_config.quant_policy, ) + shape = list(shape) + dtype = _get_kv_cache_dtype(model_config) + if cache_config.quant_policy in (4, 8): + dtype = torch.uint8 + return CacheDesc(shape=shape, dtype=dtype) - def _allocate_cache(self, num_blocks: int, device: torch.device): - """Allocate cache implement.""" - key_block_shape = self.get_key_block_shape(local=True) - value_block_shape = self.get_value_block_shape(local=True) + @classmethod + def get_quant_cache_descs(cls, k_cache_desc: CacheDesc, v_cache_desc: CacheDesc, model_config: ModelConfig, + cache_config: CacheConfig): + """Get quant cache descs.""" + if cache_config.quant_policy == 0: + return [] - num_layers = self.num_layers - kv_cache_dtype = self.kv_cache_dtype + dtype = model_config.dtype + key_scale_zero_shape = k_cache_desc.shape[:-1] + [2] + val_scale_zero_shape = v_cache_desc.shape[:-1] + [2] + key_scale_zero_desc = CacheDesc(shape=key_scale_zero_shape, dtype=dtype) + val_scale_zero_desc = CacheDesc(shape=val_scale_zero_shape, dtype=dtype) + return [key_scale_zero_desc, val_scale_zero_desc] - key_cache = torch.empty( - size=(num_layers, num_blocks, *key_block_shape), - dtype=kv_cache_dtype, - device=device, - ) - value_cache = torch.empty( - size=(num_layers, num_blocks, *value_block_shape), - dtype=kv_cache_dtype, - device=device, - ) + @classmethod + def get_custom_cache_descs(cls, model_config: ModelConfig, cache_config: CacheConfig) -> List[CacheDesc]: + """Get custom cache descs.""" + if len(model_config.cache_shapes) == 0: + return [] - output = (key_cache, value_cache) + block_size = cache_config.block_size - if self.cache_config.quant_policy in (4, 8): - dtype = self.model_config.dtype - key_sz_cache = torch.empty( - size=(num_layers, num_blocks, *key_block_shape[:-1], 2), - dtype=dtype, - device=device, - ) - val_sz_cache = torch.empty( - size=(num_layers, 
num_blocks, *value_block_shape[:-1], 2), - dtype=dtype, - device=device, - ) - output = output + (key_sz_cache, val_sz_cache) + descs = [] + for shape, dtype in model_config.cache_shapes: + custom_shape = (block_size, *shape) + desc = CacheDesc(shape=custom_shape, dtype=dtype) + descs.append(desc) + return descs - return output + @classmethod + def allocate_caches(cls, num_blocks: int, model_config: ModelConfig, cache_config: CacheConfig, world_size: int, + device: str): + """Allocate caches.""" + + num_layers = model_config.num_layers + + # get all descs + k_cache_desc = cls.get_k_cache_desc(model_config, cache_config, world_size) + v_cache_desc = cls.get_v_cache_desc(model_config, cache_config, world_size) + quant_cache_descs = cls.get_quant_cache_descs(k_cache_desc, v_cache_desc, model_config, cache_config) + custom_cache_descs = cls.get_custom_cache_descs(model_config, cache_config) + cache_descs = [k_cache_desc, v_cache_desc] + quant_cache_descs + custom_cache_descs + + # get mempool size + mem_pool_size = 0 + for desc in cache_descs: + mem_pool_size += desc.aligned_size + + # create pool + mem_pool = torch.zeros((num_layers, num_blocks, mem_pool_size), dtype=torch.uint8, device=device) + + # slice caches + caches = [] + remain_pool = mem_pool + for desc in cache_descs: + cache = remain_pool[:, :, :desc.size].view(desc.dtype).view((num_layers, num_blocks, *desc.shape)) + remain_pool = remain_pool[:, :, desc.aligned_size:] + caches.append(cache) + return mem_pool, caches def allocate_gpu_cache(self): """Allocate caches on GPU.""" - caches = self._allocate_cache(self.num_gpu_blocks, 'cuda') - self.full_gpu_cache = caches + mem_pool, caches = self.allocate_caches( + num_blocks=self.num_gpu_blocks, + model_config=self.model_config, + cache_config=self.cache_config, + world_size=self.world_size, + device='cuda', + ) + self.full_gpu_cache = mem_pool self.local_gpu_cache = list(zip(*caches)) return self.local_gpu_cache def allocate_cpu_cache(self): """Allocate caches on Host.""" - caches = self._allocate_cache(self.num_cpu_blocks, 'cpu') - - self.full_cpu_cache = caches + mem_pool, caches = self.allocate_caches( + num_blocks=self.num_cpu_blocks, + model_config=self.model_config, + cache_config=self.cache_config, + world_size=self.world_size, + device='cpu', + ) + self.full_cpu_cache = mem_pool self.local_cpu_cache = list(zip(*caches)) return self.local_cpu_cache + @staticmethod + def get_custom_cache_shape_impl(num_layers: int, num_blocks: int, block_size: int, shape: List[int]): + """Get single block shape.""" + return (num_layers, num_blocks, block_size, *shape) + + @staticmethod + def _allocate_single_custom_cache(shape: Sequence[int], dtype: torch.dtype, device: str): + """Allocate custom cache.""" + return torch.empty(shape, dtype=dtype, device=device) + + def allocate_custom_cache(self, device: str): + """Allocate custom caches on GPU.""" + num_layers = self.model_config.num_layers + custom_caches = [] + for shape, dtype in self.model_config.cache_shapes: + custom_shape = self.get_custom_cache_shape_impl( + num_layers=num_layers, + num_blocks=self.num_gpu_blocks, + block_size=self.block_size, + shape=shape, + ) + custom_cache = self._allocate_single_custom_cache(shape=custom_shape, dtype=dtype, device=device) + custom_caches.append(custom_cache) + return custom_caches + @torch.inference_mode() def _swap(self, src: List[torch.Tensor], dst: List[torch.Tensor], src_to_dst: Dict[int, int]): """Move caches from src memory to dst memory. 
@@ -248,7 +363,7 @@ def swap_in(self, src_to_dst: Dict[int, int]) -> None: Args: src_to_dst (Dict[int, int]): Map between src and dst. """ - self._swap(self.full_cpu_cache, self.full_gpu_cache, src_to_dst) + self._swap([self.full_cpu_cache], [self.full_gpu_cache], src_to_dst) def swap_out(self, src_to_dst: Dict[int, int]) -> None: """Move cache from Device to Host. @@ -256,14 +371,10 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: Args: src_to_dst (Dict[int, int]): Map between src and dst. """ - self._swap(self.full_gpu_cache, self.full_cpu_cache, src_to_dst) + self._swap([self.full_gpu_cache], [self.full_cpu_cache], src_to_dst) @classmethod - def get_cache_block_size(cls, - block_size: int, - model_config: ModelConfig, - world_size: int = 1, - quant_policy: int = 0) -> int: + def get_cache_block_size(cls, cache_config: CacheConfig, model_config: ModelConfig, world_size: int = 1) -> int: """Get the required cache size of the model. Args: @@ -273,49 +384,15 @@ def get_cache_block_size(cls, Return: int: Required memory size in bytes. """ - num_layers = model_config.num_layers - key_head_size = model_config.k_head_dim - value_head_size = model_config.v_head_dim - if key_head_size is None: - key_head_size = model_config.head_dim - if value_head_size is None: - value_head_size = model_config.head_dim - key_shape = cls._get_key_block_shape_impl( - model_config, - block_size=block_size, - head_size=key_head_size, - world_size=world_size, - local=True, - quant_policy=quant_policy, - ) - value_shape = cls._get_value_block_shape_impl( - model_config, - block_size=block_size, - head_size=value_head_size, + mem_pool, _ = cls.allocate_caches( + num_blocks=1, + model_config=model_config, + cache_config=cache_config, world_size=world_size, - quant_policy=quant_policy, - local=True, + device='meta', ) - if quant_policy == 0: - dtype = model_config.dtype - key_block = torch.empty(key_shape, dtype=dtype, device='meta') - value_block = torch.empty(value_shape, dtype=dtype, device='meta') - mem_key_block = key_block.numel() * key_block.element_size() - mem_value_block = value_block.numel() * value_block.element_size() - elif quant_policy in (4, 8): - key_block = torch.empty(key_shape, dtype=torch.uint8, device='meta') - value_block = torch.empty(value_shape, dtype=torch.uint8, device='meta') - key_scale_zero_block = torch.empty((*key_shape[:-1], 2), dtype=model_config.dtype, device='meta') - value_scale_zero_block = torch.empty((*value_shape[:-1], 2), dtype=model_config.dtype, device='meta') - mem_key_block = key_block.numel() * key_block.element_size() + key_scale_zero_block.numel( - ) * key_scale_zero_block.element_size() - mem_value_block = value_block.numel() * value_block.element_size() + value_scale_zero_block.numel( - ) * value_scale_zero_block.element_size() - else: - raise ValueError(f'unsupported quant_policy {quant_policy}') - - total = num_layers * (mem_key_block + mem_value_block) - return total + + return mem_pool.numel() * mem_pool.element_size() """ Metheds for PD Disaggregation Begin. 
""" @@ -324,7 +401,7 @@ def p2p_initialize(self, migration_init_request: DistServeInitRequest) -> DistSe self.migration_backend_impl = MIGRATION_BACKENDS.module_dict[self.cache_config.migration_backend.name]() migration_init_request.rank = self.rank self.migration_backend_impl.p2p_initialize(migration_init_request) - for i, t in enumerate(self.full_gpu_cache): + for i, t in enumerate([self.full_gpu_cache]): if t.numel() == 0: continue register_mr_request = DistServeRegisterMRMessage(protocol=migration_init_request.protocol, @@ -370,7 +447,7 @@ def get_assignment_batch(mr_key, block_ids, assignment_len, layer_stride, remote remote_layer_stride = self.migration_backend_impl.links[ remote_engine_id].remote_engine_config.num_gpu_blocks * assignment_len - for i, t in enumerate(self.full_gpu_cache): + for i, t in enumerate([self.full_gpu_cache]): if t.numel() == 0: continue assignment_batch.extend( diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 9e50843a80..0c748c29f9 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -164,8 +164,7 @@ def update_configs(self): vocal_size = self.model_config.vocab_size tp = self.dist_config.attn_config.tp - cache_block_size = CacheEngine.get_cache_block_size(cache_config.block_size, model_config, tp, - cache_config.quant_policy) + cache_block_size = CacheEngine.get_cache_block_size(cache_config, model_config, tp) runtime_mem, max_prefill_token_num = self._get_runtime_size(free_mem, cache_block_size, vocal_size) if cache_config.max_prefill_token_num != max_prefill_token_num: if max_prefill_token_num <= 0: diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 2ba2850c75..a48153f08a 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -232,6 +232,7 @@ def model_forward( context = ctx_mgr.build_context( inputs=inputs, model_config=cache_engine.model_config, + cache_config=cache_engine.cache_config, kv_caches=cache_engine.gpu_cache, kv_quant_policy=cache_engine.cache_config.quant_policy, ) diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py index f4ae57714b..006ddada66 100644 --- a/lmdeploy/pytorch/kernels/cuda/__init__.py +++ b/lmdeploy/pytorch/kernels/cuda/__init__.py @@ -3,7 +3,6 @@ from .alibi_pagedattention import alibi_paged_attention_fwd from .apply_rotary_pos_emb import apply_rotary_pos_emb from .fill_kv_cache import fill_kv_cache -from .flash_mla import flash_mla_fwd from .flashattention import flash_attention_fwd from .flatten_kv_cache import flatten_kv_cache from .fused_moe import fused_moe diff --git a/lmdeploy/pytorch/kernels/cuda/bitonic_topk.py b/lmdeploy/pytorch/kernels/cuda/bitonic_topk.py new file mode 100644 index 0000000000..3a30d6e44f --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/bitonic_topk.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import triton +import triton.language as tl +from triton.language import core +from triton.language.standard import _log2 + + +@triton.jit +def _compare_and_swap(x, ids, flip, i: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = tl.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = tl.arange(0, 2)[None, :, None] + left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape) + right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape) + left = tl.reshape(left, x.shape) + right = tl.reshape(right, x.shape) + + # idx + y_idx = tl.reshape(ids, shape) + left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape) + right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape) + left_idx = tl.reshape(left_idx, x.shape) + right_idx = tl.reshape(right_idx, x.shape) + + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + cond = (left > right) ^ flip + cond = cond.to(tl.int1) + + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + + new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids)) + + return ret.to(x.dtype, bitcast=True), new_ids + + +@triton.jit +def _bitonic_merge(x, ids, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr): + """order_type 0 == ascending order_type 1 == descending order_type 2 == + alternating.""" + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) + return x, ids + + +@triton.jit +def argsort(x, ids, dim: tl.constexpr = None, descending: tl.constexpr = core.CONSTEXPR_0): + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert(_dim == len(x.shape) - 1, 'only minor dimension is currently supported') + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = _log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) + return x, ids + + +@triton.jit +def _bitonic_topk_kernel0(score_ptr, + seqlen_ptr, + out_ptr, + ids_ptr, + stride_m: tl.constexpr, + K: tl.constexpr, + fill: tl.constexpr, + descending: tl.constexpr = core.CONSTEXPR_0): + """kernel0.""" + batch_id = tl.program_id(0).to(tl.int64) + block_id = tl.program_id(1).to(tl.int64) + + seqlen = tl.load(seqlen_ptr + batch_id) + + if block_id * K >= seqlen: + return + + offs_k = tl.arange(0, K) + origin_ids = block_id * K + offs_k + mask = (origin_ids < seqlen) + score_ptrs = score_ptr + batch_id * stride_m + origin_ids + scores = tl.load(score_ptrs, mask=mask, other=-1e6) + ids = tl.where(mask, origin_ids, fill) + ids = origin_ids + + scores, ids = argsort(scores, ids, 0, descending) + + tl.store(out_ptr + batch_id * stride_m + origin_ids, scores, mask=mask) + tl.store(ids_ptr + batch_id * stride_m + origin_ids, ids, mask=mask) + + +@triton.jit +def _concate(a, b): + """concate.""" + c = tl.join(a, b) # [k, 2] + c = c.trans() # [2, k] + # there are bugs in `tr.ravel` when triton<=3.2.0 + c = tl.reshape(c, (a.numel + b.numel, )) + return c + + +@triton.jit +def _split(a, k): + """split.""" + a = a.reshape(2, k) + a = a.trans() + return tl.split(a) + + +@triton.jit +def _bitonic_topk_kernel1(score_ptr, + ids_ptr, + seqlen_ptr, + out_ptr, + stride_m: tl.constexpr, + K: tl.constexpr, + fill: tl.constexpr, + descending: tl.constexpr = core.CONSTEXPR_0): + """kernel1.""" + batch_id = tl.program_id(0).to(tl.int64) + + seqlen = tl.load(seqlen_ptr + batch_id) + offs_k = tl.arange(0, K) + score_ptrs = score_ptr + batch_id * stride_m + offs_k + ids_ptrs = ids_ptr + batch_id * stride_m + offs_k + + # initialize + pos = offs_k + mask = pos < seqlen + scores = tl.load(score_ptrs, mask=mask, other=-1e6) + ids = tl.load(ids_ptrs, mask=mask, other=fill) + + pos = 2 * K - 1 - offs_k + score_ptrs = score_ptr + batch_id * stride_m + pos + ids_ptrs = ids_ptr + batch_id * stride_m + pos + + stage: tl.constexpr = _log2(2 * K) + for k in tl.range(K, seqlen, K, num_stages=3): + mask = pos < seqlen + new_scores = tl.load(score_ptrs, mask=mask, other=-1e6) + new_ids = tl.load(ids_ptrs, mask=mask, other=fill) + + merged_scores = _concate(scores, new_scores) + merged_ids = _concate(ids, new_ids) + + merged_scores, merged_ids = _bitonic_merge(merged_scores, merged_ids, stage, descending, stage) + # merged_scores, merged_ids = argsort(merged_scores, merged_ids, 0, descending) + + scores, _ = _split(merged_scores, K) + ids, _ = _split(merged_ids, K) + score_ptrs += K + ids_ptrs += K + 
pos += K + + out_ptrs = out_ptr + batch_id * K + offs_k + tl.store(out_ptrs, ids) + + +def bitonic_topk(scores: torch.Tensor, + q_seqlens: torch.Tensor, + kv_seqlens: torch.Tensor, + k: int, + fill: int = -1, + descending: bool = True): + """Bitnoic topk.""" + num_tokens = scores.size(0) + max_kv_len = scores.size(-1) + + if num_tokens != kv_seqlens.size(0): + repeat_kv_seqlens = torch.repeat_interleave(kv_seqlens, q_seqlens, output_size=num_tokens) + else: + repeat_kv_seqlens = kv_seqlens + tmp_scores = torch.empty_like(scores) + tmp_ids = torch.empty_like(scores, dtype=torch.int32) + grid = (num_tokens, triton.cdiv(max_kv_len, k)) + _bitonic_topk_kernel0[grid](scores, + repeat_kv_seqlens, + tmp_scores, + tmp_ids, + stride_m=scores.stride(0), + K=k, + fill=fill, + descending=1 if descending else 0, + num_warps=4) + + out = kv_seqlens.new_empty((num_tokens, k), dtype=torch.int32) + _bitonic_topk_kernel1[(num_tokens, )](tmp_scores, + tmp_ids, + repeat_kv_seqlens, + out, + stride_m=tmp_scores.stride(0), + K=k, + fill=fill, + descending=1 if descending else 0, + num_warps=4) + return out diff --git a/lmdeploy/pytorch/kernels/cuda/ds_index.py b/lmdeploy/pytorch/kernels/cuda/ds_index.py new file mode 100644 index 0000000000..51af0bc8e9 --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/ds_index.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import triton +import triton.language as tl + +from .utils import get_device_props + + +@triton.jit +def _fp8_index_kernel( + q_ptr, + q_s_ptr, + k_cache_ptr, + k_s_cache_ptr, + cu_seqlen_q_ptr, + k_seqlen_ptr, + block_offset_ptr, + out_ptr, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_qsm: tl.constexpr, + stride_qsh: tl.constexpr, + stride_kb: tl.constexpr, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + stride_ksb: tl.constexpr, + stride_ksn: tl.constexpr, + stride_boff0, + stride_boff1: tl.constexpr, + stride_om: tl.constexpr, + stride_on: tl.constexpr, + max_q_seqlen, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + NUM_SPLIT: tl.constexpr, +): + """Fp8 index kernel.""" + m_id = tl.program_id(0).to(tl.int64) + split_id = tl.program_id(1).to(tl.int64) + + assert stride_qd == 1 + assert stride_kd == 1 + + batch_id = m_id // max_q_seqlen + q_id = m_id % max_q_seqlen + q_start = tl.load(cu_seqlen_q_ptr + batch_id) + q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_start + if q_id >= q_seqlen: + return + + k_seqlen = tl.load(k_seqlen_ptr + batch_id) + if k_seqlen <= 0: + return + + q_pos = q_start + q_id + offs_h = tl.arange(0, BLOCK_H) + offs_d = tl.arange(0, BLOCK_D) + offs_n = tl.arange(0, BLOCK_N) + + q_ptrs = q_ptr + q_pos * stride_qm + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd + q_s_ptrs = q_s_ptr + q_pos * stride_qsm + offs_h * stride_qsh + q = tl.load(q_ptrs) + q_s = tl.load(q_s_ptrs) + + k_ptrs = k_cache_ptr + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd + k_s_ptrs = k_s_cache_ptr + offs_n * stride_ksn + o_ptrs = out_ptr + q_pos * stride_om + offs_n * stride_on + split_id * BLOCK_N * stride_on + boff_ptr = block_offset_ptr + batch_id * stride_boff0 + split_id * stride_boff1 + + num_blocks = tl.cdiv(k_seqlen, BLOCK_N) + for boff_id in tl.range(split_id, num_blocks, NUM_SPLIT, num_stages=3): + boff = tl.load(boff_ptr) + + k = tl.load(k_ptrs + boff * stride_kb) + k_s = tl.load(k_s_ptrs + boff * stride_ksb) + + logits = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32) + logits = tl.dot(q, k, acc=logits) + 
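+        # ReLU the per-head QK logits and apply the per-head query scale, then sum over heads and
+        # rescale by the per-token key scale to produce one index score per cached position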
logits = tl.maximum(logits, 0) * q_s[:, None] + logits_sum = tl.sum(logits, axis=0) * k_s + + tl.store(o_ptrs, logits_sum, mask=offs_n + boff_id * BLOCK_N < k_seqlen) + boff_ptr += NUM_SPLIT * stride_boff1 + o_ptrs += NUM_SPLIT * BLOCK_N * stride_on + + +def fp8_index(q: torch.Tensor, + q_s: torch.Tensor, + k_cache: torch.Tensor, + k_s_cache: torch.Tensor, + cu_seqlen_q: torch.Tensor, + k_seqlens: torch.Tensor, + block_offset: torch.Tensor, + max_q_seqlen: int = None, + max_k_seqlen: int = None): + """Fp8 index. + + q: (cum_seqlen, num_heads, head_dim) + q_s: (cum_seqlen, num_heads) + k_cache: (num_blocks, block_size, head_dim) + k_s_cache: (num_blocks, block_size) + cu_seqlen_q: (batch_size,) + cu_seqlen_k: (batch_size,) + block_offset: (batch_size, num_blocks) + """ + assert q.dim() == 3 + assert k_cache.dim() == 3 + assert q_s.dim() == 2 + assert k_s_cache.dim() == 2 + cum_seqlen, num_heads, head_dim = q.shape + block_size = k_cache.size(1) + batch_size = k_seqlens.numel() + is_decoding = batch_size == cum_seqlen + if max_k_seqlen is None: + max_num_blocks = k_cache.size(0) + max_k_seqlen = max_num_blocks * block_size + + # max q seqlen + if is_decoding: + if max_q_seqlen is None: + max_q_seqlen = 1 + assert max_q_seqlen == 1 + elif max_q_seqlen is None: + max_q_seqlen = cum_seqlen + + assert q.stride(-1) == 1 and k_cache.stride(-1) == 1 + + out = q.new_empty((cum_seqlen, max_k_seqlen), dtype=torch.float32) + + num_warps = 4 + device_idx = q.device.index + props = get_device_props(device_idx) + num_sm = props['multi_processor_count'] + # estimated occupancy 12.5% + warps_per_sm = props['warps_per_sm'] // 8 + assert warps_per_sm >= num_warps + cta_per_sm = warps_per_sm // num_warps + cta_per_device = num_sm * cta_per_sm + # we better have a tensor to indicate batch id of each q + M = max_q_seqlen * batch_size + NUM_SPLIT = max(1, triton.cdiv(cta_per_device, M)) + NUM_SPLIT = 1 + grid = (M, NUM_SPLIT) + + _fp8_index_kernel[grid](q, + q_s, + k_cache, + k_s_cache, + cu_seqlen_q, + k_seqlens, + block_offset, + out, + *q.stride(), + *q_s.stride(), + *k_cache.stride(), + *k_s_cache.stride(), + *block_offset.stride(), + *out.stride(), + max_q_seqlen=max_q_seqlen, + BLOCK_H=num_heads, + BLOCK_N=block_size, + BLOCK_D=head_dim, + NUM_SPLIT=NUM_SPLIT, + num_warps=num_warps) + return out diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index eeaad8f989..9eb2e76046 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Literal +from typing import Literal, Optional +import torch import triton import triton.language as tl from torch import Tensor @@ -266,9 +267,9 @@ def _fill_kv_cache_quant_kernel( def fill_kv_cache(k_states: Tensor, - v_states: Tensor, + v_states: Optional[Tensor], k_caches: Tensor, - v_caches: Tensor, + v_caches: Optional[Tensor], q_start_loc: Tensor, q_seq_length: Tensor, kv_seq_length: Tensor, @@ -285,6 +286,10 @@ def fill_kv_cache(k_states: Tensor, b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3) else: raise RuntimeError('Unsupported layout.') + if v_states is None: + v_states = k_states[..., :0] + if v_caches is None: + v_caches = k_caches[..., :0] block_offsets = block_offsets.contiguous() batch_size = block_offsets.size(0) @@ -386,3 +391,241 @@ def fill_kv_cache(k_states: Tensor, num_warps=4, num_stages=3, ) + + +@triton.jit +def _quant_blocked_fp8(x, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + dtype: tl.constexpr, + GROUP_SIZE: tl.constexpr = 128): + x = x.to(tl.float32) + M: tl.constexpr = x.shape[0] + N: tl.constexpr = x.shape[1] + rfp8_max: tl.constexpr = 1 / fp8_max + x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE) + scale = tl.maximum(tl.max(tl.abs(x), axis=2, keep_dims=True), 1e-6) * rfp8_max + out = x / scale + + out = tl.clamp(out, fp8_min, fp8_max) + out = out.to(dtype) + out = out.reshape(M, N) + scale = scale.reshape(M, N // GROUP_SIZE) + return out, scale + + +@triton.jit +def _fill_kv_cache_blocked_fp8_kernel( + KStates, + VStates, + KCaches, + VCaches, + KSCaches, + VSCaches, + cu_seqlen_q_ptr, + KVSeqLens, + BlockOffsets, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + is_decoding: tl.constexpr, + head_dim: tl.constexpr, + head_dim_v: tl.constexpr, + stride_kss, + stride_ksh, + stride_ksd, + stride_vss, + stride_vsh, + stride_vsd, + stride_kcn: tl.constexpr, + stride_kcb: tl.constexpr, + stride_kch: tl.constexpr, + stride_kcd: tl.constexpr, + stride_vcn: tl.constexpr, + stride_vcb: tl.constexpr, + stride_vch: tl.constexpr, + stride_vcd: tl.constexpr, + stride_kscn: tl.constexpr, + stride_kscb: tl.constexpr, + stride_ksch: tl.constexpr, + stride_kscd: tl.constexpr, + stride_vscn: tl.constexpr, + stride_vscb: tl.constexpr, + stride_vsch: tl.constexpr, + stride_vscd: tl.constexpr, + stride_boff, + GROUP_SIZE: tl.constexpr, + BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_DV: tl.constexpr, +): + """Fill kv cache kernel.""" + batch_id = tl.program_id(2) + head_id = tl.program_id(0) + block_id = tl.program_id(1) + + q_startloc = tl.load(cu_seqlen_q_ptr + batch_id) + q_seqlen = tl.load(cu_seqlen_q_ptr + batch_id + 1) - q_startloc + kv_seqlen = tl.load(KVSeqLens + batch_id) + history_seqlen = kv_seqlen - q_seqlen + + kv_block_id = history_seqlen // BLOCK + block_id + + if kv_seqlen <= 0: + return + + if kv_block_id * BLOCK >= kv_seqlen: + return + + if is_decoding: + page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32) + kv_mask = tl.full((1, ), 1, dtype=tl.int1) + q_offs = tl.full((1, ), q_startloc, dtype=tl.int32) + else: + page_offs = tl.arange(0, BLOCK) + kv_offs = kv_block_id * BLOCK + page_offs + kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen) + token_off = q_startloc + kv_block_id * BLOCK - history_seqlen + q_offs = token_off + page_offs + + block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) + + d_off = tl.arange(0, BLOCK_D) + mask_ks = kv_mask[:, None] + mask_kc = mask_ks & (d_off[None, :] < head_dim) + d_off = d_off % head_dim + + BLOCK_DS: tl.constexpr = (BLOCK_D + GROUP_SIZE - 1) // 
GROUP_SIZE + ds_off = tl.arange(0, BLOCK_DS) + + ks_ptr = KStates + head_id * stride_ksh + ks_ptrs = ks_ptr + q_offs[:, None] * stride_kss + d_off[None, :] * stride_ksd + kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch + kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[None, :] * stride_kcd + ksc_ptr = KSCaches + block_off * stride_kscn + head_id * stride_ksch + ksc_ptrs = ksc_ptr + page_offs[:, None] * stride_kscb + ds_off[None, :] * stride_kscd + + if BLOCK_DV > 0: + dv_off = tl.arange(0, BLOCK_DV) + mask_vs = kv_mask[:, None] + mask_vc = mask_vs & (dv_off[None, :] < head_dim_v) + + BLOCK_DVS: tl.constexpr = (BLOCK_DV + GROUP_SIZE - 1) // GROUP_SIZE + dvs_off = tl.arange(0, BLOCK_DVS) + + dv_off = dv_off % head_dim_v + vs_ptr = VStates + head_id * stride_vsh + vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[None, :] * stride_vsd + vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch + vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[None, :] * stride_vcd + vsc_ptr = VSCaches + block_off * stride_vscn + head_id * stride_vsch + vsc_ptrs = vsc_ptr + page_offs[:, None] * stride_vscb + dvs_off[None, :] * stride_vscd + + k = tl.load(ks_ptrs, mask=mask_ks) + if BLOCK_DV > 0: + v = tl.load(vs_ptrs, mask=mask_vs) + kc, kcs = _quant_blocked_fp8(k, fp8_min, fp8_max, KCaches.dtype.element_ty, GROUP_SIZE) + tl.store(kc_ptrs, kc, mask=mask_kc) + tl.store(ksc_ptrs, kcs, mask=kv_mask[:, None] & (ds_off[None, :] < _div_up(head_dim, GROUP_SIZE))) + if BLOCK_DV > 0: + vc, vcs = _quant_blocked_fp8(v, fp8_min, fp8_max, VCaches.dtype.element_ty, GROUP_SIZE) + tl.store(vc_ptrs, vc, mask=mask_vc) + tl.store(vsc_ptrs, vcs, mask=kv_mask[:, None] & (ds_off[None, :] < _div_up(head_dim_v, GROUP_SIZE))) + + +def fill_kv_cache_blocked_fp8(k_states: Tensor, + v_states: Optional[Tensor], + k_caches: Tensor, + v_caches: Optional[Tensor], + ks_caches: Tensor, + vs_caches: Optional[Tensor], + cu_seqlen_q: Tensor, + kv_seqlens: Tensor, + max_q_seqlen: int, + block_offsets: Tensor, + group_size: int = 128, + kv_layout: str = 'bshd'): + """Fill key/value state to cache for paged attention with fp8 + quantization.""" + + if kv_layout == 'bshd': + b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3) + elif kv_layout == 'bhsd': + b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3) + else: + raise RuntimeError('Unsupported layout.') + + if v_states is None: + v_states = k_states[..., :0] + if v_caches is None: + v_caches = k_caches[..., :0] + if vs_caches is None: + vs_caches = ks_caches[..., :0] + + block_offsets = block_offsets.contiguous() + batch_size = block_offsets.size(0) + block_size = k_caches.size(s_dim) + num_heads = k_caches.size(h_dim) + head_dim = k_caches.size(d_dim) + head_dim_v = v_states.size(-1) + if max_q_seqlen == 1: + max_num_blocks = 1 + else: + max_num_blocks = triton.cdiv(max_q_seqlen, block_size) + 1 + + BLOCK = block_size + BLOCK_D = triton.next_power_of_2(head_dim) + BLOCK_DV = triton.next_power_of_2(head_dim_v) + if k_caches.data_ptr() == v_caches.data_ptr() and head_dim_v <= head_dim: + BLOCK_DV = 0 + + dtype = k_caches.dtype + finfo = torch.finfo(dtype) + fmin = finfo.min + fmax = finfo.max + + grid = (num_heads, max_num_blocks, batch_size) + is_decoding = max_q_seqlen == 1 + _fill_kv_cache_blocked_fp8_kernel[grid]( + k_states, + v_states, + k_caches, + v_caches, + ks_caches, + vs_caches, + cu_seqlen_q, + kv_seqlens, + block_offsets, + fp8_min=fmin, + fp8_max=fmax, + is_decoding=is_decoding, + head_dim=head_dim, + head_dim_v=head_dim_v, + 
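+        # strides of the incoming token-major k/v states, followed by the paged cache and scale-cache strides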
stride_kss=k_states.stride(-3), + stride_ksh=k_states.stride(-2), + stride_ksd=k_states.stride(-1), + stride_vss=v_states.stride(-3), + stride_vsh=v_states.stride(-2), + stride_vsd=v_states.stride(-1), + stride_kcn=k_caches.stride(b_dim), + stride_kcb=k_caches.stride(s_dim), + stride_kch=k_caches.stride(h_dim), + stride_kcd=k_caches.stride(d_dim), + stride_vcn=v_caches.stride(b_dim), + stride_vcb=v_caches.stride(s_dim), + stride_vch=v_caches.stride(h_dim), + stride_vcd=v_caches.stride(d_dim), + stride_kscn=ks_caches.stride(b_dim), + stride_kscb=ks_caches.stride(s_dim), + stride_ksch=ks_caches.stride(h_dim), + stride_kscd=ks_caches.stride(d_dim), + stride_vscn=vs_caches.stride(b_dim), + stride_vscb=vs_caches.stride(s_dim), + stride_vsch=vs_caches.stride(h_dim), + stride_vscd=vs_caches.stride(d_dim), + stride_boff=block_offsets.stride(0), + GROUP_SIZE=group_size, + BLOCK=BLOCK, + BLOCK_D=BLOCK_D, + BLOCK_DV=BLOCK_DV, + num_warps=4, + ) diff --git a/lmdeploy/pytorch/kernels/cuda/flash_mla.py b/lmdeploy/pytorch/kernels/cuda/flash_mla.py deleted file mode 100644 index 1a3209edeb..0000000000 --- a/lmdeploy/pytorch/kernels/cuda/flash_mla.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple - -import torch - - -def flash_mla_fwd( - q: torch.Tensor, - k_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - head_dim_v: int, - tile_scheduler_metadata: torch.Tensor, - num_splits: torch.Tensor, - softmax_scale: Optional[float] = None, - causal: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Arguments: - q: (batch_size, num_heads_q, head_dim). - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head_dim of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. - softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). - causal: bool. Whether to apply causal attention mask. - - Return: - out: (batch_size, num_heads_q, head_dim_v). 
- """ - import flash_mla - out, _ = flash_mla.flash_mla_with_kvcache( - q, - k_cache, - block_table, - cache_seqlens, - head_dim_v, - tile_scheduler_metadata, - num_splits, - softmax_scale, - causal, - ) - return out.squeeze(1) diff --git a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py index bd82e347ed..facb9a69f4 100644 --- a/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/flatten_kv_cache.py @@ -333,3 +333,174 @@ def flatten_kv_cache(k_caches: Tensor, ) return k_states, v_states + + +@triton.jit +def dequant_fp8(x, scale, GROUP_SIZE: tl.constexpr): + """Dequant fp8.""" + M: tl.constexpr = x.shape[0] + N: tl.constexpr = x.shape[1] + x = x.to(scale.dtype) + x = x.reshape(M, N // GROUP_SIZE, GROUP_SIZE) + scale = scale.reshape(M, N // GROUP_SIZE, 1) + x = x * scale + x = x.reshape(M, N) + return x + + +@triton.jit +def flatten_kv_cache_mla_fp8_kernel( + kc_nope_ptr, + kc_scale_ptr, + kc_pe_ptr, + ko_ptr, + start_loc_ptr, + seqlens_ptr, + block_offsets_ptr, + stride_kcb: tl.constexpr, + stride_kcs: tl.constexpr, + stride_kch: tl.constexpr, + stride_kcd: tl.constexpr, + stride_kcsb: tl.constexpr, + stride_kcss: tl.constexpr, + stride_kcsh: tl.constexpr, + stride_kcsd: tl.constexpr, + stride_kcpb: tl.constexpr, + stride_kcps: tl.constexpr, + stride_kcph: tl.constexpr, + stride_kcpd: tl.constexpr, + stride_koh, + stride_kos: tl.constexpr, + stride_kod: tl.constexpr, + stride_boff, + OUT_SIZE, + BLOCK_BS: tl.constexpr, + BLOCK_NOPE: tl.constexpr, + BLOCK_PE: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + """Mla fp8 flatten kv cache kernel.""" + page_id = tl.program_id(0) + batch_id = tl.program_id(1) + head_id = tl.program_id(2) + num_batches = tl.num_programs(1) + + seqlen = tl.load(seqlens_ptr + batch_id) + start_loc = tl.load(start_loc_ptr + batch_id) + # fill last block to prevent attention nan + if batch_id == num_batches - 1: + seqlen = OUT_SIZE - start_loc + if page_id * BLOCK_BS >= seqlen: + return + + b_off = tl.load(block_offsets_ptr + batch_id * stride_boff + page_id) + + BLOCK_SCALE: tl.constexpr = BLOCK_NOPE // GROUP_SIZE + offs_bs = tl.arange(0, BLOCK_BS) + offs_dnope = tl.arange(0, BLOCK_NOPE) + offs_scale = tl.arange(0, BLOCK_SCALE) + offs_dpe = tl.arange(0, BLOCK_PE) + offs_obs = page_id * BLOCK_BS + tl.arange(0, BLOCK_BS) + mask_bs = offs_obs < seqlen + + offs_kc = b_off * stride_kcb + offs_bs[:, None] * stride_kcs + head_id * stride_kch + kc_nope_ptrs = (kc_nope_ptr + offs_kc + offs_dnope[None, :] * stride_kcd) + + offs_kc_scale = b_off * stride_kcsb + offs_bs[:, None] * stride_kcss + head_id * stride_kcsh + kc_scale_ptrs = (kc_scale_ptr + offs_kc_scale + offs_scale[None, :] * stride_kcsd) + + offs_kc_pe = b_off * stride_kcpb + offs_bs[:, None] * stride_kcps + head_id * stride_kcph + kc_pe_ptrs = (kc_pe_ptr + offs_kc_pe + offs_dpe[None, :] * stride_kcpd) + + offs_ko = head_id * stride_koh + (start_loc + offs_obs[:, None]) * stride_kos + ko_nope_ptrs = (ko_ptr + offs_ko + offs_dnope[None, :] * stride_kod) + ko_pe_ptrs = (ko_ptr + offs_ko + (BLOCK_NOPE + offs_dpe[None, :]) * stride_kod) + + # nope + kc_nope = tl.load(kc_nope_ptrs) + kc_scale = tl.load(kc_scale_ptrs) + ko_nope = dequant_fp8(kc_nope, kc_scale, GROUP_SIZE) + ko_nope = ko_nope.to(ko_ptr.dtype.element_ty) + tl.store(ko_nope_ptrs, ko_nope, mask=mask_bs[:, None]) + + # pe + kc_pe = tl.load(kc_pe_ptrs) + tl.store(ko_pe_ptrs, kc_pe, mask=mask_bs[:, None]) + + +def flatten_kv_cache_mla_fp8(k_caches: Tensor, + seqlens: Tensor, + 
block_offsets: Tensor, + start_loc: Tensor = None, + out_size: int = None, + out_dtype: torch.dtype = None, + flatten_kv_layout: str = 'hsd'): + """This kernel is designed to support mla fp8.""" + assert k_caches.dim() == 4 + + b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3) + + if out_dtype is None: + out_dtype = torch.bfloat16 + + if out_size is None or out_size <= 0: + out_size = k_caches.size(b_dim) * k_caches.size(s_dim) + + # TODO: DIRTY magic number + k_caches_nope = k_caches[..., :512] + k_caches_scale = k_caches[..., 512:512 + 16].view(torch.float32) + k_caches_pe = k_caches[..., 512 + 16:].view(out_dtype) + + if start_loc is None: + start_loc = seqlens.cumsum(0) - seqlens + + batch_size, num_blocks = block_offsets.size() + num_heads = k_caches.size(h_dim) + k_head_dim = 576 + BLOCK_NOPE = 512 + BLOCK_PE = 64 + BLOCK_BS = k_caches.size(s_dim) + if flatten_kv_layout == 'hsd': + k_states = k_caches.new_empty(num_heads, out_size, k_head_dim, dtype=out_dtype) + stride_koh = k_states.stride(0) + stride_kos = k_states.stride(1) + elif flatten_kv_layout == 'shd': + k_states = k_caches.new_empty(out_size, num_heads, k_head_dim, dtype=out_dtype) + stride_koh = k_states.stride(1) + stride_kos = k_states.stride(0) + else: + raise RuntimeError(f'Unsupported layout: {flatten_kv_layout}.') + + grid = (num_blocks, batch_size, num_heads) + flatten_kv_cache_mla_fp8_kernel[grid]( + k_caches_nope, + k_caches_scale, + k_caches_pe, + k_states, + start_loc, + seqlens, + block_offsets, + stride_kcb=k_caches_nope.stride(b_dim), + stride_kcs=k_caches_nope.stride(s_dim), + stride_kch=k_caches_nope.stride(h_dim), + stride_kcd=k_caches_nope.stride(d_dim), + stride_kcsb=k_caches_scale.stride(b_dim), + stride_kcss=k_caches_scale.stride(s_dim), + stride_kcsh=k_caches_scale.stride(h_dim), + stride_kcsd=k_caches_scale.stride(d_dim), + stride_kcpb=k_caches_pe.stride(b_dim), + stride_kcps=k_caches_pe.stride(s_dim), + stride_kcph=k_caches_pe.stride(h_dim), + stride_kcpd=k_caches_pe.stride(d_dim), + stride_koh=stride_koh, + stride_kos=stride_kos, + stride_kod=k_states.stride(2), + stride_boff=block_offsets.stride(0), + OUT_SIZE=out_size, + BLOCK_BS=BLOCK_BS, + BLOCK_NOPE=BLOCK_NOPE, + BLOCK_PE=BLOCK_PE, + GROUP_SIZE=128, + ) + + return k_states diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 13e35fd1ae..1a636ce380 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -9,7 +9,7 @@ # from torch import distributed as dist import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.backends import get_backend -from lmdeploy.pytorch.config import DLLMConfig, ModelConfig +from lmdeploy.pytorch.config import CacheConfig, DLLMConfig, ModelConfig from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor if TYPE_CHECKING: @@ -144,6 +144,7 @@ class ModelInputs: model_metas: List[Dict[str, Any]] = None dp_meta: 'DPMeta' = None enable_microbatch: bool = False + is_dummy: bool = False def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None): """Update input ids.""" @@ -239,6 +240,7 @@ def __make_next_vision_inputs(flatten_mms: List, start: int): end = min(max_seq_len, start + split_size) max_q_seqlen = end - start + max_kv_seqlen += max_q_seqlen if isinstance(max_q_seqlen, torch.Tensor): max_q_seqlen = max_q_seqlen.item() max_kv_seqlen += max_q_seqlen @@ -299,6 +301,7 @@ class StepContext: """ input_ids: torch.LongTensor model_config: ModelConfig + cache_config: CacheConfig block_offsets: torch.IntTensor position_ids: 
torch.LongTensor attention_mask: torch.LongTensor @@ -329,6 +332,7 @@ def new( cls, inputs: ModelInputs, model_config: ModelConfig, + cache_config: CacheConfig, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): @@ -365,10 +369,13 @@ def new( # seq_len + history_length kv_seqlens = q_seqlens + history_seqlens kv_seqlens -= inputs.num_ignored_history + if inputs.is_dummy: + kv_seqlens = torch.zeros_like(kv_seqlens) ret = StepContext( input_ids=inputs.input_ids, model_config=model_config, + cache_config=cache_config, block_offsets=inputs.block_offsets, position_ids=position_ids, input_embeddings=input_embeddings, @@ -453,6 +460,7 @@ def build_context( self, inputs: ModelInputs, model_config: ModelConfig, + cache_config: CacheConfig, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, ): @@ -460,6 +468,7 @@ def build_context( return StepContext.new( inputs, model_config, + cache_config, kv_caches, kv_quant_policy, ) diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index a10e5da520..83cf98ea0c 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -1285,6 +1285,14 @@ def __skip_nextn(name, nextn_keys): return True return False + def __skip_layers(): + """We might change the number of layers so we can debug the model + with less gpus.""" + import re + matches = re.findall(r'\.layers\.(\d+)\.', name) + layer_id = int(matches[0]) + return layer_id >= self.config.num_hidden_layers + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ('.gate_up_proj', '.gate_proj', 0), @@ -1334,6 +1342,10 @@ def __skip_nextn(name, nextn_keys): # skip nextn if __skip_nextn(name, nextn_keys): continue + + if __skip_layers(): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: continue diff --git a/lmdeploy/pytorch/models/deepseek_v32.py b/lmdeploy/pytorch/models/deepseek_v32.py new file mode 100644 index 0000000000..75b776249f --- /dev/null +++ b/lmdeploy/pytorch/models/deepseek_v32.py @@ -0,0 +1,388 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
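+# DeepSeek-V3.2 support: reuses the DeepSeek-V2/V3 MLA building blocks and adds a lightweight
+# FP8 indexer that selects the top-k cached tokens per query for sparse (NSA) attention.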
+from typing import Any, Sequence, Tuple + +import torch +from torch import nn + +from lmdeploy.pytorch.distributed import get_dist_manager, get_ep_world_rank +from lmdeploy.pytorch.model_inputs import StepContextManager +from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType, build_rotary_embedding, + build_rotary_params) +from lmdeploy.pytorch.nn.eplb import EPLBManager +from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_o_proj, build_rowwise_linear +from lmdeploy.pytorch.nn.nsa import IndexerTopKFP8 + +from .deepseek_v2 import (DeepseekV2Attention, DeepseekV2BMM, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, + DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, yarn_get_mscale) + + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.bfloat16 + from fast_hadamard_transform import hadamard_transform + hidden_size = x.size(-1) + return hadamard_transform(x, scale=hidden_size**-0.5) + + +class Indexer(nn.Module): + + def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + try: + import fast_hadamard_transform # noqa: F401 + except ImportError: + raise ImportError('Please install fast_hadamard_transform package.') + quant_config = getattr(config, 'quantization_config', None) + # self.dim: int = 2048 + self.dim: int = config.hidden_size + self.n_heads: int = config.index_n_heads + self.n_local_heads = config.index_n_heads + self.head_dim: int = config.index_head_dim + self.rope_head_dim: int = config.qk_rope_head_dim + self.index_topk: int = config.index_topk + self.q_lora_rank: int = config.q_lora_rank + self.wq_b = build_colwise_linear(self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quant_config) + self.wk = build_colwise_linear(self.dim, + self.head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quant_config) + self.k_norm = nn.LayerNorm(self.head_dim, dtype=dtype, device=device) + self.weights_proj = build_colwise_linear(self.dim, + self.n_heads, + bias=False, + dtype=dtype, + device=device, + is_tp=False) + self.softmax_scale = self.head_dim**-0.5 + self.scale_fmt = quant_config['scale_fmt'] + self.apply_rotary_pos_emb = ApplyRotaryEmb() + self.indexer_topk = IndexerTopKFP8(self.index_topk, self.softmax_scale, block_size=128, fill=-1) + + def forward(self, + x: torch.Tensor, + qr: torch.Tensor, + freqs_cis: torch.Tensor, + index_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: Any = None): + q = self.wq_b(qr) + q = q.unflatten(-1, (-1, self.head_dim)) + q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1) + k = self.wk(x) + k = self.k_norm(k) + k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1) + + # apply rotary embedding + cos, sin = freqs_cis + q_pe, k_pe = self.apply_rotary_pos_emb( + q_pe, + k_pe[..., None, :], + cos, + sin, + inplace=False, + ) + k_pe = k_pe[0, :] + k_nope = k_nope[0, :, None] + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe, k_nope], dim=-1) + q = rotate_activation(q) + k = rotate_activation(k) + + weights = self.weights_proj(x) * self.n_heads**-0.5 + + return self.indexer_topk(q[0], k[:, 0], weights[0], index_cache[0], index_cache[1], attn_metadata=attn_metadata) + + +class DeepseekV32Attention(DeepseekV2Attention): + + def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None): + 
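+        # rebuild the MLA projections from DeepseekV2Attention directly (no super().__init__)
+        # so the extra FP8 indexer below can be attached to the layer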
nn.Module.__init__(self) + quantization_config = getattr(config, 'quantization_config', None) + self.q_lora_rank = config.q_lora_rank + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + num_key_value_heads = getattr(config, 'num_key_value_heads', 1) + use_flash_mla = getattr(config, 'use_flash_mla', False) + + if self.q_lora_rank is None: + self.q_proj = build_colwise_linear( + self.hidden_size, + self.num_heads * self.q_head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=True, + quant_config=quantization_config, + dp_disable_tp=True, + ) + else: + self.q_a_proj = build_colwise_linear( + self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + ) + self.q_a_layernorm = RMSNorm(config.q_lora_rank, + 1e-6, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.q_b_proj = build_colwise_linear( + config.q_lora_rank, + self.num_heads * self.q_head_dim, + bias=False, + dtype=dtype, + device=device, + is_tp=True, + quant_config=quantization_config, + dp_disable_tp=True, + ) + + self.kv_a_proj_with_mqa = build_colwise_linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=False, + quant_config=quantization_config, + ) + self.kv_a_layernorm = RMSNorm(config.kv_lora_rank, + 1e-6, + quant_config=quantization_config, + dtype=dtype, + device=device) + self.kc = DeepseekV2BMM(self.num_heads, + config.qk_nope_head_dim, + config.kv_lora_rank, + dtype=dtype, + device=device) + + self.apply_rotary_pos_emb = ApplyRotaryEmb() + + self.softmax_scale = self.q_head_dim**(-0.5) + + if config.rope_scaling is not None: + mscale_all_dim = config.rope_scaling.get('mscale_all_dim', 0) + scaling_factor = config.rope_scaling['factor'] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_fwd = Attention(self.num_heads, + config.kv_lora_rank + self.qk_rope_head_dim, + scale=self.softmax_scale, + num_kv_heads=num_key_value_heads, + v_head_size=config.kv_lora_rank, + num_replicate_kv_heads=num_replicate_kv_heads, + use_flash_mla=use_flash_mla) + + self.vc = DeepseekV2BMM(self.num_heads, config.kv_lora_rank, self.v_head_dim, dtype=dtype, device=device) + self.o_proj = build_o_proj( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + dtype=dtype, + device=device, + is_tp=True, + quant_config=quantization_config, + ) + + self.indexer = Indexer(config, dtype=dtype, device=device) + + def _q_proj(self, hidden_states, num_heads: int, nope_size: int, pe_size: int): + """Q proj.""" + q_len = hidden_states.size(1) + + query_states = hidden_states.new_empty(q_len, num_heads, nope_size + pe_size) + + if self.q_lora_rank is None: + qr = hidden_states + q = self.q_proj(hidden_states) + else: + qr = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q = self.q_b_proj(qr) + q = q.view(q_len, num_heads, self.q_head_dim) + # q_pe: (q_len, num_heads, qk_rope_head_dim) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, 
self.qk_rope_head_dim], dim=-1) + # q_nope: (q_len, num_heads, kv_lora_rank) + q_nope_out = query_states[..., :nope_size] + self.kc(q_nope, q_nope_out) + return query_states, q_pe, qr + + def _kv_proj(self, hidden_states, nope_size: int): + """Kv proj.""" + # (q_len, 1, nope_size + pe_size) + key_states = self.kv_a_proj_with_mqa(hidden_states[0, :, None]) + # (q_len, 1, pe_size) + k_pe = key_states[..., nope_size:] + # kv_a_layernorm + value_states = key_states[..., :nope_size] + value_states = self.kv_a_layernorm(value_states) + key_states[..., :nope_size] = value_states + return key_states, value_states, k_pe + + def _qkv_proj(self, hidden_states: torch.Tensor, num_heads: int): + """Qkv proj.""" + nope_size = self.kv_lora_rank + pe_size = self.qk_rope_head_dim + query_states, q_pe, qr = self._q_proj(hidden_states, num_heads, nope_size, pe_size) + key_states, value_states, k_pe = self._kv_proj(hidden_states, nope_size) + + return query_states, key_states, value_states, q_pe, k_pe, qr + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Sequence[torch.Tensor] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + dist_ctx = get_dist_manager().current_context() + if dist_ctx.dp > 1: + num_heads = self.num_heads + else: + world_size = dist_ctx.world_size + num_heads = self.num_heads // world_size + nope_size = self.kv_lora_rank + q_len = hidden_states.size(1) + + # qkv_proj + query_states, key_states, value_states, q_pe, k_pe, qr = self._qkv_proj(hidden_states, num_heads=num_heads) + + cos, sin = rotary_pos_emb + q_pe, k_pe = self.apply_rotary_pos_emb( + q_pe, + k_pe, + cos, + sin, + inplace=False, + ) + query_states[..., nope_size:] = q_pe + key_states[..., nope_size:] = k_pe + + topk_indices = self.indexer(hidden_states, qr, rotary_pos_emb, past_key_value[-2:], attn_metadata=attn_metadata) + + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[0][..., :nope_size], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + nsa_indices=topk_indices, + ) + attn_bmm_out = attn_output.new_empty(q_len, num_heads, self.v_head_dim) + + self.vc(attn_output, attn_bmm_out) + attn_output = attn_bmm_out.flatten(-2, -1)[None] + attn_output = self.o_proj(attn_output) + + return attn_output + + +class DeepseekV32DecoderLayer(DeepseekV2DecoderLayer): + + def __init__(self, config: Any, layer_idx: int, dtype: torch.dtype = None, device: torch.device = None): + nn.Module.__init__(self) + self.layer_idx = layer_idx + quantization_config = None + + # build attention layer + if getattr(config, 'use_mla', True): + self.self_attn = DeepseekV32Attention(config, dtype=dtype, device=device) + else: + # deepseek-vl2-tiny uses MHA LlamaAttention structure + from lmdeploy.pytorch.models.llama import LlamaAttention + self.self_attn = LlamaAttention(config, dtype=dtype, device=device) + + # mlp + self.mlp = (DeepseekV2MoE(config, layer_idx, dtype=dtype, device=device) if + (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0) else DeepseekV2MLP(config, dtype=dtype, device=device)) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build 
attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + +class DeepseekV32Model(DeepseekV2Model): + + def __init__(self, config: Any, dtype: torch.dtype = None, device: torch.device = None): + nn.Module.__init__(self) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + if get_dist_manager().current_context().dist_config.enable_eplb: + ep_size_, _ = get_ep_world_rank() + EPLBManager.init_global_eplb_metadata(ep_size_, config.n_routed_experts, config.num_hidden_layers) + self.layers = nn.ModuleList([ + DeepseekV32DecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, quant_config=None, dtype=dtype, device=device) + + emb_type = RopeType.LinearScaling + rope_dim = config.qk_rope_head_dim if getattr(config, 'use_mla', True) else (config.hidden_size // + config.num_attention_heads) + rope_max_pos_emb = config.max_position_embeddings + rope_base = config.rope_theta + + rope_params = dict(emb_type=emb_type, dim=rope_dim, max_position_embeddings=rope_max_pos_emb, base=rope_base) + update_params = build_rotary_params(config) + rope_params.update(update_params) + self.rotary_emb = build_rotary_embedding(**rope_params) + + +class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM): + + def __init__(self, + config: Any, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + nn.Module.__init__(self) + self.config = config + self.quantization_config = getattr(config, 'quantization_config', None) + self.dtype = dtype + self.ctx_mgr = ctx_mgr + self.model = DeepseekV32Model(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + self._load_buffers = dict() diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 498e2c6554..b90bfe3ba2 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -100,6 +100,9 @@ # deepseek-v3 MODULE_MAP.update({'DeepseekV3ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'}) +# deepseek-v32 +MODULE_MAP.update({'DeepseekV32ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v32.DeepseekV32ForCausalLM'}) + # deepseek-vl2 MODULE_MAP.update({'DeepseekVLV2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_vl2.DeepseekVLV2ForCausalLM'}) diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 065aef97d1..6d0772c175 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch from torch import Tensor @@ -34,6 +34,9 @@ class CudaGraphMeta: input_buffers: BuffType = None output_buffers: BuffType = None vocab_size: int = 1 + use_mla_fp8_cache: bool = False + use_flash_mla: bool = False + mla_index_topk: Optional[int] = None class CudaGraphMixin: @@ -68,8 +71,16 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> import flash_mla # create buffers for flash mla + num_attention_heads = self.config.num_attention_heads + index_topk = graph_meta.mla_index_topk + num_heads_q = None if index_topk is None else num_attention_heads input_buffers['tile_scheduler_metadata'], input_buffers['num_splits'] = flash_mla.get_mla_metadata( - torch.ones(max_batches, dtype=torch.int32, device=device), self.config.num_attention_heads, 1) + torch.ones(max_batches, dtype=torch.int32, device=device), + num_attention_heads, + num_heads_k=1, + num_heads_q=num_heads_q, + is_fp8_kvcache=graph_meta.use_mla_fp8_cache, + topk=index_topk) # flash_mla requires block_offsets and kv_lens int32 input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int32, device=device) @@ -82,6 +93,10 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) -> # create buffer for cross_attn_metadata here input_buffers['fill_seqlens'] = torch.zeros(max_batches, dtype=torch.int64, device=device) + input_buffers['cu_seqlens'] = torch.zeros(2, max_batches + 1, dtype=torch.int32, device=device) + input_buffers['cu_seqlens_q'] = input_buffers['cu_seqlens'][0] + input_buffers['cu_seqlens_k'] = input_buffers['cu_seqlens'][1] + return input_buffers def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, position_ids: Tensor, @@ -108,6 +123,8 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens)) input_buffers['qkv_lens'].zero_() input_buffers['qkv_lens'][:, :batch_size] = qkv + input_buffers['cu_seqlens_q'][1:batch_size + 1] = input_buffers['q_seqlens'][:batch_size].cumsum(0) + input_buffers['cu_seqlens_k'][1:batch_size + 1] = input_buffers['kv_seqlens'][:batch_size].cumsum(0) if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) if 'inputs_embeds' not in input_buffers: @@ -121,10 +138,20 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p attn_metadata.q_start_loc = input_buffers['q_start_loc'] attn_metadata.q_seqlens = input_buffers['q_seqlens'] attn_metadata.kv_seqlens = input_buffers['kv_seqlens'] + attn_metadata.cu_seqlens_q = input_buffers['cu_seqlens_q'] + attn_metadata.cu_seqlens_k = input_buffers['cu_seqlens_k'] if getattr(self.config, 'use_flash_mla', False) is True: import flash_mla - tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(attn_metadata.kv_seqlens.to(torch.int32), - self.config.num_attention_heads, 1) + num_attention_heads = self.config.num_attention_heads + index_topk = graph_meta.mla_index_topk + num_heads_q = None if index_topk is None else num_attention_heads + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( + attn_metadata.kv_seqlens.to(torch.int32), + num_attention_heads, + num_heads_k=1, + num_heads_q=num_heads_q, + is_fp8_kvcache=graph_meta.use_mla_fp8_cache, + topk=index_topk) # here we use copy_ instead of = to avoid using new allocated mem for cuda graph 
input_buffers['tile_scheduler_metadata'].copy_(tile_scheduler_metadata) input_buffers['num_splits'][:new_batch_size + 1].copy_(num_splits[:new_batch_size + 1]) diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index 7a1654db4b..e2b4c5c191 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -77,9 +77,15 @@ def forward( k_scales_zeros: torch.Tensor = None, v_scales_zeros: torch.Tensor = None, s_aux: torch.Tensor = None, + nsa_indices: torch.Tensor = None, inplace: bool = True, ) -> torch.Tensor: """forward.""" + kwargs = dict() + if nsa_indices is not None: + kwargs['nsa_indices'] = nsa_indices + if s_aux is not None: + kwargs['learnable_sink'] = s_aux return self.impl.forward( query, key, @@ -89,8 +95,8 @@ def forward( attn_metadata=attn_metadata, k_scales_zeros=k_scales_zeros, v_scales_zeros=v_scales_zeros, - learnable_sink=s_aux, inplace=inplace, + **kwargs, ) @staticmethod diff --git a/lmdeploy/pytorch/nn/nsa.py b/lmdeploy/pytorch/nn/nsa.py new file mode 100644 index 0000000000..d3944bf003 --- /dev/null +++ b/lmdeploy/pytorch/nn/nsa.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor, nn + +from lmdeploy.pytorch.backends import OpType, get_backend +from lmdeploy.pytorch.backends.attention import AttentionMetadata +from lmdeploy.pytorch.backends.nsa import NSAIndexMeta +from lmdeploy.pytorch.model_inputs import get_step_ctx_manager + + +class IndexerTopKFP8(nn.Module): + + def __init__(self, topk: int, softmax_scale: float, block_size: int = 128, fill: int = -1): + super().__init__() + backend = get_backend() + index_builder = backend.get_layer_impl_builder(OpType.NSAIndexFP8) + self.index_impl = index_builder.build(topk, softmax_scale, block_size, fill) + + def forward( + self, + q: Tensor, + k: Tensor, + weights: Tensor, + k_cache: Tensor, + k_s_cache: Tensor, + attn_metadata: AttentionMetadata = None, + ): + """forward.""" + step_ctx = get_step_ctx_manager().current_context() + cache_config = step_ctx.cache_config + max_tokens = cache_config.block_size * cache_config.num_gpu_blocks + is_decoding = attn_metadata.is_decoding + if q.size(0) == attn_metadata.kv_seqlens.size(0): + is_decoding = True + max_q_seqlen = 1 if is_decoding else q.size(0) + # we need to make max_kv_seqlen=max_allocated_cache_len to enable cudagraph + max_kv_seqlen = max_tokens if is_decoding else attn_metadata.kv_flatten_size + meta = NSAIndexMeta(cu_seqlen_q=attn_metadata.cu_seqlens_q, + q_seqlens=attn_metadata.q_seqlens, + k_seqlens=attn_metadata.kv_seqlens, + block_offset=attn_metadata.block_offsets, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen) + ret = self.index_impl.forward(q, k, weights, k_cache, k_s_cache, meta=meta) + return ret diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py index 795220bd02..77cda0da20 100644 --- a/lmdeploy/pytorch/strategies/base/model_inputs.py +++ b/lmdeploy/pytorch/strategies/base/model_inputs.py @@ -38,6 +38,7 @@ def make_dummy_inputs(batch_size: int, max_kv_seqlen=max_kv_seqlen, sum_kv_seqlen=num_tokens, local_adapter_ids=local_adapter_ids, + is_dummy=True, ) diff --git a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py index 1e734c4073..369862e60e 100644 --- a/lmdeploy/pytorch/third_party/deep_gemm/__init__.py +++ b/lmdeploy/pytorch/third_party/deep_gemm/__init__.py @@ -42,3 +42,43 @@ def fp8_gemm_nt(a, b, d, c, recipe=None, 
compiled_dim='nk', disable_ue8m0_cast=F N, _ = b[0].shape with _log_jit_build(M, N, K): gemm_fp8_fp8_bf16_nt(a, b, d) + + +try: + from deep_gemm import m_grouped_fp8_gemm_nt_contiguous +except Exception: + from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous + + def m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, recipe=None, compiled_dims='nk', disable_ue8m0_cast=False): + assert recipe is None + assert compiled_dims == 'nk' + assert disable_ue8m0_cast is False + return m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(a, b, d, m_indices) + + +try: + from deep_gemm import m_grouped_fp8_gemm_nt_masked +except Exception: + from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked + + def m_grouped_fp8_gemm_nt_masked(a, + b, + d, + masked_m, + expected_m, + recipe=None, + compiled_dims='nk', + disable_ue8m0_cast=False): + assert recipe is None + assert compiled_dims == 'nk' + assert disable_ue8m0_cast is False + return m_grouped_gemm_fp8_fp8_bf16_nt_masked(a, b, d, masked_m, expected_m) + + +try: + from deep_gemm import get_mn_major_tma_aligned_tensor +except Exception: + from deep_gemm import get_col_major_tma_aligned_tensor + + def get_mn_major_tma_aligned_tensor(x): + return get_col_major_tma_aligned_tensor(x) diff --git a/lmdeploy/pytorch/transformers/__init__.py b/lmdeploy/pytorch/transformers/__init__.py new file mode 100644 index 0000000000..bfafdb1899 --- /dev/null +++ b/lmdeploy/pytorch/transformers/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import lru_cache + +from transformers import AutoConfig + +from lmdeploy.utils import get_logger + + +@lru_cache() +def register_config(model_type: str): + if model_type == 'deepseek_v32': + from lmdeploy.pytorch.transformers.configuration_deepseek_v32 import DeepseekV32Config + AutoConfig.register(DeepseekV32Config.model_type, DeepseekV32Config) + else: + logger.debug(f'Can not register config for model_type: {model_type}') + + +logger = get_logger('lmdeploy') + + +def config_from_pretrained(pretrained_model_name_or_path: str, **kwargs): + try: + return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + except ValueError as e: + logger.debug(f'AutoConfig.from_pretrained failed: {e}, try register config manually.') + # some models (dsv32) does not provide auto map for config + from transformers import PretrainedConfig + trust_remote_code = kwargs.pop('trust_remote_code', None) + config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) + model_type = config_dict.get('model_type', None) + if trust_remote_code is not None: + kwargs['trust_remote_code'] = trust_remote_code + register_config(model_type) + + return AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) diff --git a/lmdeploy/pytorch/transformers/configuration_deepseek_v32.py b/lmdeploy/pytorch/transformers/configuration_deepseek_v32.py new file mode 100644 index 0000000000..59f91fb3ef --- /dev/null +++ b/lmdeploy/pytorch/transformers/configuration_deepseek_v32.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config + + +class DeepseekV32Config(DeepseekV3Config): + model_type = 'deepseek_v32' + + def __init__(self, index_head_dim=128, index_n_heads=64, index_topk=2048, **kwargs): + super().__init__(**kwargs) + self.index_head_dim = index_head_dim + self.index_n_heads = index_n_heads + self.index_topk = index_topk diff --git a/tests/pytorch/kernel/test_bitonic_topk.py b/tests/pytorch/kernel/test_bitonic_topk.py new file mode 100644 index 0000000000..dd58f73361 --- /dev/null +++ b/tests/pytorch/kernel/test_bitonic_topk.py @@ -0,0 +1,67 @@ +import pytest +import torch + + +class TestBitonicTopk: + + @pytest.fixture + def device(self): + yield 'cuda' + + @pytest.fixture + def k(self): + yield 2048 + + @pytest.fixture + def q_seqlens(self, device): + ret = [4, 16, 1, 32] + ret = torch.tensor(ret, dtype=torch.int32, device=device) + yield ret + + @pytest.fixture + def kv_seqlens(self, device): + ret = [1024, 2048, 4096, 4096 + 133] + ret = torch.tensor(ret, dtype=torch.int32, device=device) + yield ret + + @pytest.fixture + def batch_size(self, kv_seqlens): + return kv_seqlens.numel() + + @pytest.fixture + def max_kv_len(self, kv_seqlens): + return kv_seqlens.max().item() + + @pytest.fixture + def scores(self, q_seqlens, max_kv_len, device): + num_tokens = q_seqlens.sum().item() + yield torch.randn((num_tokens, max_kv_len), device=device) + + @pytest.fixture + def gt(self, scores, q_seqlens, kv_seqlens, k): + batch_size = kv_seqlens.numel() + num_tokens, _ = scores.shape + topk_indices = torch.empty((num_tokens, k), dtype=torch.int32, device=scores.device) + topk_indices.fill_(-1) + + start = 0 + for i in range(batch_size): + q_seqlen = q_seqlens[i].item() + seqlen = kv_seqlens[i].item() + tmp_k = min(seqlen, k) + end = start + q_seqlen + _, topk_indices[start:end, :seqlen] = torch.topk(scores[start:end, :seqlen], + tmp_k, + largest=True, + sorted=True) + start = end + return topk_indices + + def test_bitonic_topk(self, scores, q_seqlens, kv_seqlens, k, gt): + from lmdeploy.pytorch.kernels.cuda.bitonic_topk import bitonic_topk + out = bitonic_topk(scores, q_seqlens=q_seqlens, kv_seqlens=kv_seqlens, k=k, fill=-1) + gt[gt < 0] = 0 + out[out < 0] = 0 + gt_score = torch.gather(scores, 1, gt.to(torch.int64)) + out_score = torch.gather(scores, 1, out.to(torch.int64)) + torch.testing.assert_close(gt_score, out_score) diff --git a/tests/pytorch/kernel/test_ds_index.py b/tests/pytorch/kernel/test_ds_index.py new file mode 100644 index 0000000000..dc5ef009f5 --- /dev/null +++ b/tests/pytorch/kernel/test_ds_index.py @@ -0,0 +1,146 @@ +import pytest +import torch + + +def _make_A(M, K, group_size, out_dtype, device): + quant_A = torch.randn(M, K // group_size, group_size, dtype=torch.float32, device=device) + # -1 ~ 1 + quant_A = quant_A * 2 - 1 + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_A.abs().amax(-1, keepdim=True) + quant_A *= scaling + quant_A = quant_A.to(out_dtype).to(torch.float32) + + # create scale and A + scale = torch.randn(M, K // group_size, dtype=torch.float32, device=device) + scale /= fmax + A = quant_A * scale[..., None] + + A = A.reshape(M, K) + quant_A = quant_A.reshape(M, K).to(out_dtype) + scale = scale.T.contiguous().T + return A, quant_A, scale + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0') +class TestDSIndex: + + @pytest.fixture + def num_heads(self): + yield 64 + + @pytest.fixture + def 
head_dim(self): + yield 128 + + @pytest.fixture + def block_size(self): + yield 64 + + @pytest.fixture + def device(self): + yield 'cuda' + + @pytest.fixture + def q_seqlens(self, request): + yield request.param + + @pytest.fixture + def kv_seqlens(self, request): + yield request.param + + @pytest.fixture + def k_seqlens(self, kv_seqlens, device): + yield torch.tensor(kv_seqlens, dtype=torch.int32, device=device) + + @pytest.fixture + def cu_seqlen_q(self, q_seqlens, device): + yield torch.tensor([0] + list(q_seqlens), dtype=torch.int32, device=device).cumsum(0) + + @pytest.fixture + def cu_seqlen_kv(self, kv_seqlens, device): + yield torch.tensor([0] + list(kv_seqlens), dtype=torch.int32, device=device).cumsum(0) + + @pytest.fixture + def query(self, q_seqlens, num_heads, head_dim, device): + total_len = sum(q_seqlens) + fp_q, q, q_s = _make_A(total_len * num_heads, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device=device) + fp_q = fp_q.view(total_len, num_heads, head_dim) + q = q.view(total_len, num_heads, head_dim) + q_s = q_s.view(total_len, num_heads) + yield fp_q, q, q_s + + @pytest.fixture + def q(self, query): + yield query[1] + + @pytest.fixture + def q_s(self, query): + yield query[2] + + @pytest.fixture + def key(self, kv_seqlens, head_dim): + total_len = sum(kv_seqlens) + fp_k, k, k_s = _make_A(total_len, head_dim, head_dim, out_dtype=torch.float8_e4m3fn, device='cuda') + fp_k = fp_k.view(total_len, head_dim) + k = k.view(total_len, head_dim) + k_s = k_s.view(total_len) + yield fp_k, k, k_s + + @pytest.fixture + def k(self, key): + yield key[1] + + @pytest.fixture + def k_s(self, key): + yield key[2] + + @pytest.fixture + def cache_key(self, k, k_s, kv_seqlens, block_size, head_dim): + batch_size = len(kv_seqlens) + max_num_blocks = (max(kv_seqlens) + block_size - 1) // block_size + + # get block offsets + batch_ids = torch.arange(batch_size, device='cuda') * max_num_blocks + block_ids = torch.arange(max_num_blocks, device='cuda') + block_offsets = (batch_ids[:, None] + block_ids[None, :]) + + k_cache = torch.zeros((max_num_blocks * batch_size * block_size, head_dim), + dtype=torch.float8_e4m3fn, + device='cuda') + k_s_cache = torch.zeros((max_num_blocks * batch_size * block_size), dtype=torch.float32, device='cuda') + + k = k.split(kv_seqlens, dim=0) + k_s = k_s.split(kv_seqlens, dim=0) + for i in range(batch_size): + size = k[i].size(0) + start = i * max_num_blocks * block_size + end = start + size + k_cache[start:end] = k[i] + k_s_cache[start:end] = k_s[i] + + k_cache = k_cache.view(batch_size * max_num_blocks, block_size, head_dim) + k_s_cache = k_s_cache.view(batch_size * max_num_blocks, block_size) + + yield k_cache, k_s_cache, block_offsets + + @pytest.fixture + def k_cache(self, cache_key): + yield cache_key[0] + + @pytest.fixture + def k_s_cache(self, cache_key): + yield cache_key[1] + + @pytest.fixture + def block_offset(self, cache_key): + yield cache_key[2] + + @pytest.mark.parametrize('q_seqlens', [(1, 1, 1, 1), (1024, 2048, 1024, 1)], indirect=True) + @pytest.mark.parametrize('kv_seqlens', [(2048, 4096, 1024, 128)], indirect=True) + def test_fp8_index(self, q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset): + # gt requires tilelang, so this test just ensure the kernel works + from lmdeploy.pytorch.kernels.cuda.ds_index import fp8_index + fp8_index(q, q_s, k_cache, k_s_cache, cu_seqlen_q, k_seqlens, block_offset) diff --git a/tests/pytorch/kernel/test_fill_kv_cache.py b/tests/pytorch/kernel/test_fill_kv_cache.py index 51a9bf32ac..1a5bc565a0 
100644 --- a/tests/pytorch/kernel/test_fill_kv_cache.py +++ b/tests/pytorch/kernel/test_fill_kv_cache.py @@ -265,3 +265,128 @@ def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, k_scales_ze torch.testing.assert_close(v_scales_zeros, gt[3]) torch.testing.assert_close(k_caches, gt[0]) torch.testing.assert_close(v_caches, gt[1]) + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0') +class TestFillKVCacheBlockedFP8(TestFillKVCache): + + @pytest.fixture + def quant_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def num_heads(self): + yield 4 + + @pytest.fixture + def head_dim(self): + yield 128 + + @pytest.fixture + def block_size(self): + yield 64 + + @pytest.fixture + def group_size(self): + yield 128 + + @pytest.fixture + def cu_seqlen_q(self, q_start_loc, q_seq_length): + batch_size = q_start_loc.size(0) + cu_seqlen = torch.zeros(batch_size + 1, dtype=torch.int32).cuda() + cu_seqlen[1:] = q_start_loc + q_seq_length + return cu_seqlen + + @pytest.fixture + def k_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, quant_dtype): + shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim) + yield torch.full(shape, 0, dtype=quant_dtype).cuda() + + @pytest.fixture + def v_caches(self, k_caches): + yield torch.zeros_like(k_caches) + + @pytest.fixture + def ks_caches(self, batch_size, max_num_blocks, block_size, num_heads, head_dim, group_size): + shape = (batch_size * max_num_blocks, block_size, num_heads, head_dim // group_size) + yield torch.full(shape, 0.0).cuda() + + @pytest.fixture + def vs_caches(self, ks_caches): + yield torch.ones_like(ks_caches) + + @pytest.fixture + def gt(self, k_states, v_states, group_size, quant_dtype): + from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 + batch_size = k_states.size(0) + num_heads = k_states.size(1) + head_dim = k_states.size(2) + + k_states = k_states.flatten(0, -2) + v_states = v_states.flatten(0, -2) + quant_k, quant_ks = quant_fp8(k_states, group_size=group_size, dtype=quant_dtype) + quant_v, quant_vs = quant_fp8(v_states, group_size=group_size, dtype=quant_dtype) + + quant_k = quant_k.view(batch_size, num_heads, head_dim) + quant_ks = quant_ks.view(batch_size, num_heads, head_dim // group_size) + quant_v = quant_v.view(batch_size, num_heads, head_dim) + quant_vs = quant_vs.view(batch_size, num_heads, head_dim // group_size) + + yield quant_k, quant_ks, quant_v, quant_vs + + def uncache(self, k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q, kv_seqlens, block_offsets): + batch_size = block_offsets.size(0) + out_k = [] + out_ks = [] + out_v = [] + out_vs = [] + q_seqlens = cu_seqlen_q[1:] - cu_seqlen_q[:-1] + for bidx in range(batch_size): + seqlen = q_seqlens[bidx].item() + kv_len = kv_seqlens[bidx].item() + start = kv_len - seqlen + end = kv_len + k = k_caches[block_offsets[bidx]].reshape(-1, k_caches.size(-2), k_caches.size(-1)) + ks = ks_caches[block_offsets[bidx]].reshape(-1, ks_caches.size(-2), ks_caches.size(-1)) + v = v_caches[block_offsets[bidx]].reshape(-1, v_caches.size(-2), v_caches.size(-1)) + vs = vs_caches[block_offsets[bidx]].reshape(-1, vs_caches.size(-2), vs_caches.size(-1)) + out_k.append(k[start:end]) + out_ks.append(ks[start:end]) + out_v.append(v[start:end]) + out_vs.append(vs[start:end]) + out_k = torch.cat(out_k, dim=0) + out_ks = torch.cat(out_ks, dim=0) + out_v = torch.cat(out_v, dim=0) + out_vs = torch.cat(out_vs, dim=0) + return out_k, out_ks, out_v, out_vs + + 
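+    # the first case exercises pure decoding (one new token per sequence); the second mixes prefill and decoding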
@pytest.mark.parametrize(['seq_lens', 'history_lens'], [ + ((1, 1, 1, 1), (1, 128, 256, 200)), + ((1, 64, 128, 50), (1, 128, 256, 200)), + ], + indirect=True) + def test_fill_kv_cache(self, k_states, v_states, k_caches, v_caches, ks_caches, vs_caches, block_offsets, + cu_seqlen_q, kv_seq_length, max_q_seq_length, gt, group_size): + from lmdeploy.pytorch.kernels.cuda.fill_kv_cache import fill_kv_cache_blocked_fp8 + fill_kv_cache_blocked_fp8(k_states, + v_states, + k_caches, + v_caches, + ks_caches, + vs_caches, + cu_seqlen_q, + kv_seq_length, + max_q_seq_length, + block_offsets=block_offsets, + group_size=group_size) + + gt_k, gt_ks, gt_v, gt_vs = gt + + # uncache + out_k, out_ks, out_v, out_vs = self.uncache(k_caches, ks_caches, v_caches, vs_caches, cu_seqlen_q, + kv_seq_length, block_offsets) + + torch.testing.assert_close(out_k.float(), gt_k.float()) + torch.testing.assert_close(out_ks, gt_ks) + torch.testing.assert_close(out_v.float(), gt_v.float()) + torch.testing.assert_close(out_vs, gt_vs) diff --git a/tests/pytorch/kernel/test_flatten_kv_cache.py b/tests/pytorch/kernel/test_flatten_kv_cache.py index 5e2fffb510..80cfc34c84 100644 --- a/tests/pytorch/kernel/test_flatten_kv_cache.py +++ b/tests/pytorch/kernel/test_flatten_kv_cache.py @@ -170,3 +170,75 @@ def atol(self): @pytest.fixture def rtol(self): yield 1e-3 + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason='require device with cc>=9.0') +class TestFlattenKVCacheMLAFP8(TestFlattenKVCache): + + @pytest.fixture + def out_dtype(self): + yield torch.bfloat16 + + @pytest.fixture + def num_heads(self): + yield 1 + + @pytest.fixture + def head_dim(self): + yield 576 + + @pytest.fixture + def block_size(self): + yield 64 + + @pytest.fixture + def k_cache_mla(self, k_caches): + from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 + num_blocks, block_size, num_heads, _ = k_caches.shape + k_cache_pe = k_caches[:, :, :, 512:] + k_cache_nope = k_caches[:, :, :, :512].flatten(0, -2) + k_cache_nope, k_cache_scale = quant_fp8(k_cache_nope, group_size=128) + k_cache_nope = k_cache_nope.view(num_blocks, block_size, num_heads, -1) + k_cache_scale = k_cache_scale.reshape(num_blocks, block_size, num_heads, -1).to(torch.float32) + dtype = k_cache_nope.dtype + out = torch.cat([k_cache_nope, k_cache_scale.view(dtype), k_cache_pe.view(dtype)], dim=-1) + yield out + + def _dequant(self, k_cache_mla): + k_cache_nope = k_cache_mla[..., :512].to(torch.float32) + k_cache_scale = k_cache_mla[..., 512:512 + 16].view(torch.float32) + k_cache_pe = k_cache_mla[..., 512 + 16:].view(torch.bfloat16) + k_cache_nope = k_cache_nope.unflatten(-1, (-1, 128)) + k_cache_scale = k_cache_scale[..., None] + k_cache_nope *= k_cache_scale + k_cache_nope = k_cache_nope.flatten(-2, -1).to(k_cache_pe.dtype) + k_cache = torch.cat([k_cache_nope, k_cache_pe], dim=-1) + return k_cache + + @pytest.fixture + def gt(self, k_cache_mla, kv_lens, block_offsets, block_size, num_heads, out_size, head_dim): + k_caches = self._dequant(k_cache_mla) + k_states = k_caches.new_empty(num_heads, out_size, head_dim) + start_loc = 0 + for kv_len, block_offs in zip(kv_lens, block_offsets): + remain_len = kv_len + for idx, _ in enumerate(range(0, kv_len, block_size)): + b_off = block_offs[idx] + block_len = min(block_size, remain_len) + end_loc = start_loc + block_len + k_block = k_caches[b_off, :block_len] + k_states[:, start_loc:end_loc] = k_block.transpose(0, 1) + start_loc = end_loc + remain_len -= block_len + + yield k_states + + def test_flatten_kv_cache(self, 
k_cache_mla, kv_seqlens, block_offsets, out_size, out_dtype, gt): + from lmdeploy.pytorch.kernels.cuda.flatten_kv_cache import flatten_kv_cache_mla_fp8 + + k_states = flatten_kv_cache_mla_fp8(k_cache_mla, + kv_seqlens, + block_offsets, + out_size=out_size, + out_dtype=out_dtype) + torch.testing.assert_close(k_states, gt)