diff --git a/README.md b/README.md
index 8bd7fcbc15..5138e0a455 100644
--- a/README.md
+++ b/README.md
@@ -150,6 +150,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
Phi-3.5-MoE (16x3.8B)
Phi-4-mini (3.8B)
MiniCPM3 (4B)
+ SDAR (1.7B-30B)
gpt-oss (20B, 120B)
diff --git a/README_ja.md b/README_ja.md
index 3537e84935..75d05390ad 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -137,6 +137,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
Phi-3.5-MoE (16x3.8B)
Phi-4-mini (3.8B)
MiniCPM3 (4B)
+ SDAR (1.7B-30B)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 7d7635dc70..179ffa466d 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -151,6 +151,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
Phi-3.5-MoE (16x3.8B)
Phi-4-mini (3.8B)
MiniCPM3 (4B)
+ SDAR (1.7B-30B)
gpt-oss (20B, 120B)
|
diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index e2a4f724c6..6e4243cca6 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -307,6 +307,10 @@ def parse_args():
# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.eager_mode(pt_group)
+ ArgumentHelper.dllm_block_length(pt_group)
+ ArgumentHelper.dllm_unmasking_strategy(pt_group)
+ ArgumentHelper.dllm_denoising_steps(pt_group)
+ ArgumentHelper.dllm_confidence_threshold(pt_group)
tp_act = ArgumentHelper.tp(pt_group)
cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group)
@@ -363,6 +367,10 @@ def main():
quant_policy=args.quant_policy,
dtype=args.dtype,
distributed_executor_backend=args.distributed_executor_backend,
+ dllm_block_length=args.dllm_block_length,
+ dllm_unmasking_strategy=args.dllm_unmasking_strategy,
+ dllm_denoising_steps=args.dllm_denoising_steps,
+ dllm_confidence_threshold=args.dllm_confidence_threshold,
)
if args.use_uvloop:
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index d4ffba4787..aa28854d8a 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -120,6 +120,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
+| SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - |
```{note}
* [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index d3d6946020..8e9e3fef20 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -120,6 +120,7 @@
| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
+| SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - |
```{note}
* [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16
diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py
index 916b91e527..d71198791f 100644
--- a/lmdeploy/cli/cli.py
+++ b/lmdeploy/cli/cli.py
@@ -55,6 +55,7 @@ def add_parser_chat():
ArgumentHelper.adapters(pt_group)
ArgumentHelper.device(pt_group)
ArgumentHelper.eager_mode(pt_group)
+ ArgumentHelper.dllm_block_length(pt_group)
# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
tp_act = ArgumentHelper.tp(pt_group)
diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index 7fac263001..6a9e9f2b13 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -92,6 +92,10 @@ def add_parser_api_server():
ArgumentHelper.eager_mode(pt_group)
ArgumentHelper.disable_vision_encoder(pt_group)
ArgumentHelper.logprobs_mode(pt_group)
+ ArgumentHelper.dllm_block_length(pt_group)
+ ArgumentHelper.dllm_unmasking_strategy(pt_group)
+ ArgumentHelper.dllm_denoising_steps(pt_group)
+ ArgumentHelper.dllm_confidence_threshold(pt_group)
# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
@@ -219,6 +223,10 @@ def api_server(args):
hf_overrides=args.hf_overrides,
disable_vision_encoder=args.disable_vision_encoder,
logprobs_mode=args.logprobs_mode,
+ dllm_block_length=args.dllm_block_length,
+ dllm_unmasking_strategy=args.dllm_unmasking_strategy,
+ dllm_denoising_steps=args.dllm_denoising_steps,
+ dllm_confidence_threshold=args.dllm_confidence_threshold,
)
else:
from lmdeploy.messages import TurbomindEngineConfig
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index 4f852b076d..bfd94182d0 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -624,6 +624,36 @@ def logprobs_mode(parser):
choices=[None, 'raw_logits', 'raw_logprobs'],
help='The mode of logprobs.')
+ @staticmethod
+ def dllm_block_length(parser):
+ """dllm_block_length for dllm."""
+ return parser.add_argument('--dllm-block-length', type=int, default=None, help='Block length for dllm')
+
+ @staticmethod
+ def dllm_unmasking_strategy(parser):
+ """Dllm unmasking strategy."""
+ return parser.add_argument('--dllm-unmasking-strategy',
+ type=str,
+ default='low_confidence_dynamic',
+ choices=['low_confidence_dynamic', 'low_confidence_static', 'sequential'],
+ help='The unmasking strategy for dllm.')
+
+ @staticmethod
+ def dllm_denoising_steps(parser):
+ """Dllm denoising steps."""
+ return parser.add_argument('--dllm-denoising-steps',
+ type=int,
+ default=None,
+ help='The number of denoising steps for dllm.')
+
+ @staticmethod
+ def dllm_confidence_threshold(parser):
+ """Dllm confidence threshold."""
+ return parser.add_argument('--dllm-confidence-threshold',
+ type=float,
+ default=0.85,
+ help='The confidence threshold for dllm.')
+
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
class FlexibleArgumentParser(argparse.ArgumentParser):
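As a minimal sketch (not part of this patch), the new `ArgumentHelper` methods can be attached to any argparse group the same way `add_parser_api_server()` does; the stand-alone parser below is only for illustration.

```python
# Hedged example: wiring the new dllm options onto an argparse group.
import argparse

from lmdeploy.cli.utils import ArgumentHelper

parser = argparse.ArgumentParser('demo')  # hypothetical demo parser
pt_group = parser.add_argument_group('PyTorch engine arguments')
ArgumentHelper.dllm_block_length(pt_group)
ArgumentHelper.dllm_unmasking_strategy(pt_group)
ArgumentHelper.dllm_denoising_steps(pt_group)
ArgumentHelper.dllm_confidence_threshold(pt_group)

args = parser.parse_args(['--dllm-block-length', '4'])
assert args.dllm_block_length == 4
assert args.dllm_unmasking_strategy == 'low_confidence_dynamic'  # default value
```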
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 808c47737c..69ac361abc 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -335,6 +335,12 @@ class PytorchEngineConfig:
disable_vision_encoder (bool): Whether to disable loading vision
encoder. Default to False.
logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs']
+ dllm_block_length (int): Block length of the block diffusion model.
+ dllm_unmasking_strategy (str): Dllm unmasking strategy, options:
+ ['low_confidence_dynamic', 'low_confidence_static', 'sequential'].
+ dllm_denoising_steps (int): Number of dllm denoising steps.
+ dllm_confidence_threshold (float): Confidence threshold used by the
+ dynamic unmasking strategy.
"""
dtype: str = 'auto'
tp: int = 1
@@ -370,6 +376,12 @@ class PytorchEngineConfig:
disable_vision_encoder: bool = False
logprobs_mode: str = None
+ # dllm
+ dllm_block_length: int = None
+ dllm_unmasking_strategy: str = 'low_confidence_dynamic'
+ dllm_denoising_steps: int = None
+ dllm_confidence_threshold: float = 0.85
+
role: EngineRole = EngineRole.Hybrid
migration_backend: MigrationBackend = MigrationBackend.DLSlime
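For context, a minimal usage sketch (assuming the usual `lmdeploy.pipeline` API; the model path is a placeholder) showing how the new dllm fields of `PytorchEngineConfig` could be set for a block diffusion model such as SDAR:

```python
from lmdeploy import PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(
    dllm_block_length=4,                           # block length of the diffusion model
    dllm_unmasking_strategy='low_confidence_dynamic',
    dllm_denoising_steps=None,                     # None keeps the engine default behaviour
    dllm_confidence_threshold=0.85,                # threshold used by dynamic unmasking
)
pipe = pipeline('path/to/SDAR-model', backend_config=backend_config)  # placeholder path
print(pipe(['Hi, please introduce yourself']))
```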
diff --git a/lmdeploy/metrics/stats.py b/lmdeploy/metrics/stats.py
index b21eeaca90..6654b3d9d3 100644
--- a/lmdeploy/metrics/stats.py
+++ b/lmdeploy/metrics/stats.py
@@ -198,7 +198,10 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState):
outputs (EngineOutput): The output from the engine containing information about the current iteration.
req_state (RequestState): The state of the request, including timestamps and token counts.
"""
- self.new_generation_tokens = outputs.num_token - req_state.generation_tokens
+ new_generation_tokens = outputs.num_token - req_state.generation_tokens
+ if new_generation_tokens == 0:
+ return
+ self.new_generation_tokens = new_generation_tokens
if req_state.first_token_time == 0:
# It means the first token is generated in this iteration
req_state.first_token_time = outputs.req_metrics.token_timestamp
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index 20c48c880a..2b9b3743f1 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -737,7 +737,7 @@ class HFChatTemplate(BaseChatTemplate):
def __init__(self, model_path: str = '', **kwargs):
try:
- from transformers import AutoTokenizer, PretrainedConfig
+ from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.system_start, self.system_end = self._role_instruction('system')
self.user_start, self.user_end = self._role_instruction('user')
@@ -747,7 +747,10 @@ def __init__(self, model_path: str = '', **kwargs):
self.stop_words.append(self.tokenizer.eos_token)
if hasattr(self.tokenizer, 'eot_token') and self.tokenizer.eot_token is not None:
self.stop_words.append(self.tokenizer.eot_token)
- cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
+ try:
+ cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ except Exception:  # fall back to PretrainedConfig when AutoConfig cannot resolve the architecture
+ cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
self.is_gpt_oss = getattr(cfg, 'architectures', [''])[0] == 'GptOssForCausalLM'
if self.is_gpt_oss:
self.stop_words.append('<|call|>')
diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py
index 0792c0ccca..0c842e8871 100644
--- a/lmdeploy/pytorch/backends/attention.py
+++ b/lmdeploy/pytorch/backends/attention.py
@@ -93,6 +93,7 @@ def build(
causal: bool = True,
use_flash_mla: bool = False,
learnable_sink: bool = False,
+ block_sparse_size: int = 1,
**kwargs,
) -> AttentionImpl[T]:
"""build."""
diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py
index 2a9fcc5201..b241c384b2 100644
--- a/lmdeploy/pytorch/backends/cuda/attention.py
+++ b/lmdeploy/pytorch/backends/cuda/attention.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import functools
from dataclasses import dataclass
from typing import Literal
@@ -20,8 +21,8 @@
assert torch.ops.flash_attn_3 is not None
use_fa3 = True
except Exception:
- logger.warning('For higher performance, please install FlashAttention-3 '
- 'https://github.com/Dao-AILab/flash-attention')
+ logger.debug('For higher performance, please install FlashAttention-3 '
+ 'https://github.com/Dao-AILab/flash-attention')
@dataclass
@@ -62,6 +63,7 @@ def __init__(
sliding_window: int = None,
logit_softcapping: float = None,
causal: bool = True,
+ block_sparse_size: int = 1,
**kwargs,
):
super().__init__(
@@ -91,6 +93,7 @@ def __init__(
world_size, rank = get_tp_world_rank()
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size
+ self.block_sparse_size = block_sparse_size
def forward(
self,
@@ -116,7 +119,7 @@ def forward(
kv_flatten_size = attn_metadata.kv_flatten_size
quant_policy = attn_metadata.quant_policy
if attn_metadata.is_decoding:
- max_q_seqlen = 1
+ max_q_seqlen = self.block_sparse_size
else:
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
fill_max_q_seqlen = max_q_seqlen
@@ -213,11 +216,21 @@ def forward(
logit_softcapping=self.logit_softcapping,
sinks=learnable_sink,
causal=self.causal,
+ block_sparse_size=self.block_sparse_size,
)
return attn_output
+@functools.lru_cache
+def use_fa3_warning():
+ if use_fa3:
+ return True
+ logger.warning('For higher performance, please install FlashAttention-3 '
+ 'https://github.com/Dao-AILab/flash-attention')
+ return False
+
+
class FlashMLAImpl(TritonAttentionImpl):
def __init__(
@@ -252,6 +265,7 @@ def __init__(
from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd
self.flash_mla_fwd = flash_mla_fwd
assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
+ use_fa3_warning()
def forward(
self,
@@ -512,6 +526,14 @@ def forward(
return attn_output
+@functools.lru_cache
+def _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int):
+ enable = not alibi and not learnable_sink and block_sparse_size == 1
+ if enable and not use_fa3_warning():
+ enable = False
+ return enable
+
+
class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):
"""Triton attention builder."""
@@ -528,10 +550,13 @@ def build(
causal: bool = True,
use_flash_mla: bool = False,
learnable_sink: bool = False,
+ block_sparse_size: int = 1,
**kwargs,
) -> TritonAttentionImpl:
"""build."""
+ enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size)
if use_flash_mla is True:
+ logger.debug('Build FlashMLAImpl Attention')
return FlashMLAImpl(num_heads,
head_size,
scale=scale,
@@ -542,7 +567,8 @@ def build(
logical_softcapping=logical_softcapping,
causal=causal,
**kwargs)
- elif use_fa3 and not alibi and not learnable_sink:
+ elif enable_fa3:
+ logger.debug('Build FA3Impl Attention')
return FA3Impl(num_heads,
head_size,
scale=scale,
@@ -554,6 +580,7 @@ def build(
causal=causal,
**kwargs)
else:
+ logger.debug('Build TritonAttentionImpl Attention')
return TritonAttentionImpl(num_heads,
head_size,
scale=scale,
@@ -563,4 +590,5 @@ def build(
sliding_window=sliding_window,
logical_softcapping=logical_softcapping,
causal=causal,
+ block_sparse_size=block_sparse_size,
**kwargs)
diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py
index ccc413be75..deb6c66bfd 100644
--- a/lmdeploy/pytorch/backends/cuda/graph_runner.py
+++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -9,9 +9,11 @@
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
+from lmdeploy.pytorch.strategies.base import StrategyFactoryBase
from lmdeploy.utils import get_logger
from ..graph_runner import GraphRunner
+from .attention import TritonAttentionMetadata
logger = get_logger('lmdeploy')
@@ -146,6 +148,11 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict()
self.has_try_compile_model: bool = False
+ # strategy factory
+ build_ctx = model.ctx_mgr.build_ctx
+ strategy_factory: StrategyFactoryBase = build_ctx.strategy_factory
+ self.cudagraph_strategy = strategy_factory.build_cudagraph_strategy()
+
def check_enable_graph(self):
"""Check enable graph."""
if self.backend_config.eager_mode:
@@ -173,18 +180,24 @@ def _get_capture_tokens(self, batch_size: int):
assert False, f'Unsupported batch_size={batch_size}'
def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List,
- attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs):
+ attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs):
"""Get graph key."""
context = self.ctx_mgr.current_context()
is_decoding = context.is_decoding
- num_tokens = input_ids.numel()
+ batch_size = attn_metadata.q_seqlens.size(0)
meta = self.get_meta()
enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
if meta.padding_batch_size is None:
- new_num_tokens = self._get_capture_tokens(num_tokens)
+ batch_size = self._get_capture_tokens(batch_size)
else:
- new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
- return (new_num_tokens, is_decoding, enable_microbatch)
+ batch_size = self._get_capture_tokens(meta.padding_batch_size)
+ return (batch_size, is_decoding, enable_microbatch)
+
+ def _get_max_tokens(self, graph_key: tuple):
+ max_batches = graph_key[0]
+ is_decoding = graph_key[1]
+ assert is_decoding
+ return self.cudagraph_strategy.get_max_tokens(max_batches)
def __call__(self, **kwargs):
"""call."""
@@ -198,10 +211,10 @@ def __call__(self, **kwargs):
return self.model(**kwargs)
graph_key = self.get_graph_key(**kwargs)
- max_tokens = graph_key[0]
+ max_batches = graph_key[0]
is_decoding = graph_key[1]
if graph_key not in self._runner_map:
- max_batches = max_tokens if is_decoding else self.max_batches
+ max_tokens = self._get_max_tokens(graph_key)
runner = CUDASingleGraphRunner(self.model,
max_batches=max_batches,
max_tokens=max_tokens,
diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py
index defd43303b..20102b1deb 100644
--- a/lmdeploy/pytorch/check_env/transformers.py
+++ b/lmdeploy/pytorch/check_env/transformers.py
@@ -4,7 +4,7 @@
from .base import BaseChecker
MIN_TRANSFORMERS_VERSION = '4.33.0'
-MAX_TRANSFORMERS_VERSION = '4.53.3'
+MAX_TRANSFORMERS_VERSION = '4.56.1'
class TransformersChecker(BaseChecker):
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 05c716c6a9..ac3459e045 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import enum
from dataclasses import dataclass
from typing import Any, Dict, List, Literal
@@ -27,7 +28,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
config.dtype = torch.float16
return config
- torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+ torch_dtype = getattr(config.hf_config, 'dtype', None)
+ if torch_dtype is None:
+ torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+
# deal with case when torch_dtype is not string but torch.dtype
if isinstance(torch_dtype, torch.dtype):
torch_dtype = str(torch_dtype).split('.')[1]
@@ -200,6 +204,9 @@ class ModelConfig:
cogvlm_style: bool = False
custom_module_map: Dict[str, setattr] = None
use_flash_mla: bool = False
+ model_paradigm: str = 'ar'
+ dllm_mask_token: int = 0
+ dllm_block_length: int = None
def get_head_size(self):
"""Get head size."""
@@ -285,6 +292,38 @@ def from_hf_config(cls,
return model_config
+class UnmaskingStrategy(enum.Enum):
+ """Unmasking Strategy."""
+
+ # unmasking from left to right
+ SEQUENTIAL = enum.auto()
+ # unmasking with confidence threshold
+ LOW_CONFIDENCE_DYNAMIC = enum.auto()
+ # unmasking with topk in a block
+ LOW_CONFIDENCE_STATIC = enum.auto()
+
+ @classmethod
+ def from_str(cls, strategy: str):
+ """From string."""
+ strategy = strategy.lower()
+ if strategy == 'sequential':
+ return cls.SEQUENTIAL
+ elif strategy == 'low_confidence_dynamic':
+ return cls.LOW_CONFIDENCE_DYNAMIC
+ elif strategy == 'low_confidence_static':
+ return cls.LOW_CONFIDENCE_STATIC
+ else:
+ raise ValueError(f'Unknown unmasking strategy: {strategy}')
+
+
+@dataclass
+class DLLMConfig:
+ block_length: int = 1
+ unmasking_strategy: UnmaskingStrategy = UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC
+ denoising_steps: int = None
+ confidence_threshold: float = 0.85
+
+
@dataclass
class MiscConfig:
prefill_interval: int = 16
@@ -294,15 +333,22 @@ class MiscConfig:
hf_overrides: Dict[str, Any] = None
disable_vision_encoder: bool = False
logprobs_mode: str = None
+ dllm_config: DLLMConfig = None
@classmethod
def from_engine_config(cls, engine_config: PytorchEngineConfig):
"""From engine config."""
+ dllm_unmasking_strategy = UnmaskingStrategy.from_str(engine_config.dllm_unmasking_strategy)
+ dllm_config = DLLMConfig(block_length=engine_config.dllm_block_length,
+ unmasking_strategy=dllm_unmasking_strategy,
+ denoising_steps=engine_config.dllm_denoising_steps,
+ confidence_threshold=engine_config.dllm_confidence_threshold)
misc_config = cls(custom_module_map=engine_config.custom_module_map,
empty_init=engine_config.empty_init,
prefill_interval=engine_config.prefill_interval,
model_format=engine_config.model_format,
hf_overrides=engine_config.hf_overrides,
disable_vision_encoder=engine_config.disable_vision_encoder,
- logprobs_mode=engine_config.logprobs_mode)
+ logprobs_mode=engine_config.logprobs_mode,
+ dllm_config=dllm_config)
return misc_config
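For illustration only, the conversion added above can be exercised directly; this mirrors what the engine does internally:

```python
from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.config import MiscConfig, UnmaskingStrategy

engine_config = PytorchEngineConfig(dllm_block_length=4,
                                    dllm_unmasking_strategy='sequential')
misc_config = MiscConfig.from_engine_config(engine_config)

# the string strategy is parsed into the enum; other dllm fields are copied as-is
assert misc_config.dllm_config.block_length == 4
assert misc_config.dllm_config.unmasking_strategy is UnmaskingStrategy.SEQUENTIAL
assert misc_config.dllm_config.confidence_threshold == 0.85
```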
diff --git a/lmdeploy/pytorch/configurations/sdar.py b/lmdeploy/pytorch/configurations/sdar.py
new file mode 100644
index 0000000000..edf9cd3cad
--- /dev/null
+++ b/lmdeploy/pytorch/configurations/sdar.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .default import AutoModelConfigBuilder, DefaultModelConfigBuilder
+
+
+class SDARModelConfigBuilder(AutoModelConfigBuilder):
+
+ @classmethod
+ def condition(cls, hf_config):
+ """config."""
+ return hf_config.model_type in ['sdar', 'sdar_moe']
+
+ @classmethod
+ def build(cls, hf_config, model_path: str = None, **kwargs):
+ """build."""
+ cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs)
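+ # mask token id used for dllm denoising of SDAR models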
+ cfg.dllm_mask_token = 151669
+ cfg.model_paradigm = 'dllm'
+ return cfg
diff --git a/lmdeploy/pytorch/consts.py b/lmdeploy/pytorch/consts.py
new file mode 100644
index 0000000000..93dc2f4d71
--- /dev/null
+++ b/lmdeploy/pytorch/consts.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# dllm
+DLLM_MASKED = 0
+DLLM_UNMASKED = 1
+DLLM_CACHED = 2
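A toy illustration of how these three per-token states might be used; this is an assumption for readability, not code from this patch:

```python
# Assumed usage sketch: one dllm block of 4 positions, tracked as an int array.
import numpy as np

from lmdeploy.pytorch.consts import DLLM_CACHED, DLLM_MASKED, DLLM_UNMASKED

block = np.full(4, DLLM_MASKED, dtype=np.int8)   # a fresh block starts fully masked
block[[0, 2]] = DLLM_UNMASKED                    # a denoising step unmasks some positions
if (block != DLLM_MASKED).all():                 # only a fully unmasked block becomes cache-only
    block[:] = DLLM_CACHED
print(block.tolist())
```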
diff --git a/lmdeploy/pytorch/disagg/backend/__init__.py b/lmdeploy/pytorch/disagg/backend/__init__.py
index 3ab02d3bd6..2309e16c02 100644
--- a/lmdeploy/pytorch/disagg/backend/__init__.py
+++ b/lmdeploy/pytorch/disagg/backend/__init__.py
@@ -7,7 +7,7 @@
logger.debug('Registering DLSlime Backend')
from .dlslime import DLSlimeBackend
except ImportError:
- logger.warning('Disable DLSlime Backend')
+ logger.debug('Disable DLSlime Backend')
try:
logger.debug('Registering Mooncake Backend')
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index f94c6c9a46..2179b5a99f 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -20,13 +20,13 @@
from ..adapter.adapter import AdapterManager
from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig
-from ..messages import MessageStatus, SchedulerSequence
+from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode
from ..model_inputs import ModelInputs, VisionModelInputs
from ..paging import Scheduler
+from ..strategies import build_strategy_factory
from .base import EngineBase
from .engine_checker import EngineChecker
from .executor import build_executor
-from .logits_process import SamplingInputs
from .model_agent import BatchedOutputs
from .request import Request, RequestManager, RequestType, Response
@@ -81,6 +81,15 @@ def _update_engine_config(engine_config: PytorchEngineConfig):
if engine_config.max_batch_size is None:
engine_config.max_batch_size = get_max_batch_size(engine_config.device_type)
+ if engine_config.dllm_block_length is not None:
+ max_prefill_token_num = engine_config.max_prefill_token_num
+ max_batch_size = engine_config.max_batch_size
+ if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num:
+ engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length
+ logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} '
+ f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size '
+ f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).')
+
if engine_config.dp != 1:
if engine_config.tp == 1 and engine_config.ep == 1:
engine_config.dp = 1
@@ -139,6 +148,13 @@ def _build_misc_config(engine_config: PytorchEngineConfig):
return misc_config
+def _build_seq_meta(cache_config: CacheConfig, strategy: Any):
+ from lmdeploy.pytorch.messages import SequenceMeta
+
+ seq_meta = SequenceMeta(cache_config.block_size, strategy=strategy)
+ return seq_meta
+
+
class CounterEvent:
def __init__(self):
@@ -366,10 +382,19 @@ def __init__(self,
dtype=engine_config.dtype)
self.executor.init()
+ # strategies
+ self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config)
+ self.sampling_strategy = self.strategy_factory.build_sampling_strategy()
+ self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy()
+ self.engine_strategy = self.strategy_factory.build_engine_strategy(cache_config=cache_config,
+ scheduler_config=scheduler_config)
+ self.seq_strategy = self.strategy_factory.build_sequence_strategy()
+
self.input_processor = self.executor.get_input_processor()
cache_config = self.executor.cache_config
self.adapter_manager = self._build_adapter_manager(adapters)
- self.scheduler = Scheduler(scheduler_config, cache_config)
+ self.seq_meta = _build_seq_meta(cache_config, strategy=self.seq_strategy)
+ self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta)
# engine args
self.model_path = model_path
@@ -378,6 +403,7 @@ def __init__(self,
self.cache_config = cache_config
self.backend_config = backend_config
self.dist_config = dist_config
+ self.misc_config = self.executor.misc_config
self.max_session_len = self._get_max_session_len()
self.engine_config.num_cpu_blocks = self.cache_config.num_cpu_blocks
self.engine_config.num_gpu_blocks = self.cache_config.num_gpu_blocks
@@ -575,7 +601,7 @@ def __update_max_new_tokens(msg):
max_session_len = self.max_session_len
sampling_param = msg.sampling_param
max_new_tokens = sampling_param.max_new_tokens
- num_all_tokens = msg.num_all_tokens()
+ num_all_tokens = msg.num_valid_ids
if max_new_tokens + num_all_tokens > max_session_len:
logger.warning(
f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. '
@@ -592,14 +618,12 @@ def __update_max_new_tokens(msg):
sess = scheduler.sessions[session_id]
# TODO: support 1 session n sequence
sampling_param = req.data['sampling_param']
- return_logits = sampling_param.out_logits
if len(sess.sequences) == 0:
migration_request = req.data.get('migration_request')
assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.')
sess.add_sequence(req.data['token_ids'],
sampling_param=sampling_param,
adapter_name=req.data['adapter_name'],
- return_logits=return_logits,
multimodals=req.data.get('input_multimodals'),
input_embeddings=req.data.get('input_embeddings', ),
migration_request=migration_request,
@@ -617,11 +641,9 @@ def __update_max_new_tokens(msg):
req.data['token_ids'],
multimodals=req.data.get('input_multimodals'),
embeddings=req.data.get('input_embeddings'),
- append_tokens=True,
+ mode=UpdateTokenMode.INPUTS,
)
- msg.num_new_tokens = 0
msg.sampling_param = sampling_param
- msg.return_logits = return_logits
msg.status = MessageStatus.WAITING
__update_max_new_tokens(msg)
@@ -659,10 +681,11 @@ def __get_vlm_embeddings():
]
input_embedding_indexing = torch.zeros((batch_size, max_q_seq_length), dtype=torch.bool)
for msg_id, msg in enumerate(messages):
+ num_history_ids = msg.num_history_ids
for emb in msg.input_embeddings:
# make slice index relative to embeddings
- emb_start = emb.start - msg.history_len
- emb_end = emb.end - msg.history_len
+ emb_start = emb.start - num_history_ids
+ emb_end = emb.end - num_history_ids
input_embedding_indexing[msg_id][emb_start:emb_end] = True
return (input_embeddings, input_embedding_indexing, input_embedding_ranges)
@@ -716,7 +739,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
"""
batch_size = len(messages)
# history lengths
- history_lengths = torch.tensor([msg.history_len for msg in messages])
+ history_lengths = torch.tensor([msg.num_history_ids for msg in messages])
# input ids
token_ids = [msg.token_ids for msg in messages]
@@ -729,8 +752,8 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
seq_length = torch.tensor(seq_length, dtype=torch.long)
max_q_seqlen = seq_length.max().item()
else:
- seq_length = torch.ones(batch_size, dtype=torch.long)
- max_q_seqlen = 1
+ max_q_seqlen = len(token_ids[0])
+ seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long)
kv_seqlens = seq_length + history_lengths
max_kv_seqlen = kv_seqlens.max().item()
sum_kv_seqlen = kv_seqlens.sum().item()
@@ -780,23 +803,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool):
return model_inputs
- def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped: torch.Tensor,
- model_metas: List[Dict[str, Any]]):
- """Update scheduler."""
- if model_metas is None:
- model_metas = [None] * len(running)
- next_token_ids = next_token_ids.numpy()
- for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas):
- if msg.status != MessageStatus.LOCKED:
- continue
- update_token = token
-
- # fill token
- msg.update_token_ids(update_token, model_meta=model_meta)
- msg.num_new_tokens += 1
- if stop:
- msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED
-
def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor,
model_metas: List[Dict[str, Any]]):
"""Update scheduler."""
@@ -808,40 +814,40 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,
update_token = token
# fill token
- msg.update_token_ids(update_token, model_meta=model_meta)
- msg.num_new_tokens += 1
+ msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)
if stop:
update_token = _EMPTY_TOKEN
- msg.update_token_ids(update_token, model_meta=model_meta)
+ msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL)
msg.status = MessageStatus.STOPPED
def _make_infer_outputs(
self,
batched_outputs: BatchedOutputs,
running: SeqList,
+ is_decoding: bool,
):
"""Make infer output."""
new_token_timestamp = batched_outputs.new_token_timestamp
- next_token_ids = batched_outputs.next_token_ids
logits = batched_outputs.logits
- stopped = batched_outputs.stopped
- model_metas = batched_outputs.model_metas
logprobs = batched_outputs.logprobs
seq_length = [seq.num_token_ids for seq in running]
is_run = [seq.status == MessageStatus.LOCKED for seq in running]
- stopped = stopped.tolist()
- self.update_running(running, next_token_ids, stopped, model_metas)
+ self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding)
# generate output
outputs: Dict[int, InferOutput] = dict()
for idx, msg in enumerate(running):
if not is_run[idx]:
continue
- token_ids = msg.all_ids[-msg.num_new_tokens:]
+ token_ids = msg.generated_ids
finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED
if not finish and len(token_ids) == 0:
continue
+ resp_data = msg.resp.data
+ if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids):
+ # no new tokens
+ continue
session_id = msg.session_id
if msg.resp_cache:
cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist()
@@ -870,51 +876,6 @@ def _make_infer_outputs(
def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False):
"""Make forward inputs."""
- prefill_interval = self.scheduler_config.prefill_interval
-
- def __gather_all_ids(seqs: SeqList, sampling_inputs: SamplingInputs):
- """Gather history."""
- if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors):
- return None
- batch = len(seqs)
- max_len = max(seq.num_all_ids for seq in seqs)
- pad_id = self.model_config.bos_token_id
- pad_id = 0 if pad_id is None else pad_id
- output = torch.full((batch, max_len), pad_id, dtype=torch.int64)
- for idx, seq in enumerate(seqs):
- h_len = seq.num_all_ids
- if h_len == 0:
- continue
- h_ids = torch.from_numpy(seq.all_ids)
- output[idx, -h_len:] = h_ids
- return output
-
- def __gather_guided_input_ids(seqs: SeqList, sampling_inputs: SamplingInputs):
- """Gather input ids for guided decode."""
- if not any(sampling_inputs.response_formats or ()):
- return None
- batch = len(seqs)
- max_len = max(seq.num_new_tokens for seq in seqs)
- pad_id = self.model_config.bos_token_id
- pad_id = 0 if pad_id is None else pad_id
- output = torch.full((batch, max_len), pad_id, dtype=torch.int64)
- for idx, seq in enumerate(seqs):
- h_len = seq.num_new_tokens
- if h_len == 0:
- continue
- h_ids = torch.from_numpy(seq.all_ids[-seq.num_new_tokens:])
- output[idx, -h_len:] = h_ids
- return output
-
- def __get_num_appendable_ids(seqs: SeqList):
- """Get num appendable ids."""
- ret = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]
- return torch.tensor(ret)
-
- def __get_num_ignore_eos(seqs: SeqList):
- """Get num ignore eos."""
- ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]
- return torch.tensor(ret)
def __need_logits(seqs: SeqList):
"""Need logits."""
@@ -923,7 +884,8 @@ def __need_logits(seqs: SeqList):
scheduler = self.scheduler
logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}')
- scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval)
+ prealloc_size = self.engine_strategy.get_prealloc_size(not prefill)
+ scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size)
if enable_empty and len(scheduler_output.running) == 0:
return None
@@ -931,9 +893,10 @@ def __need_logits(seqs: SeqList):
# schedule decoding if no valid prefill reqs.
if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill:
prefill = False
- scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval)
+ prealloc_size = self.engine_strategy.get_prealloc_size(not prefill)
+ scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size)
- num_loops = 1 if prefill else prefill_interval
+ num_loops = self.engine_strategy.get_num_loops(not prefill)
running = scheduler_output.running
swap_in_map = scheduler_output.swap_in_map
swap_out_map = scheduler_output.swap_out_map
@@ -943,12 +906,10 @@ def __need_logits(seqs: SeqList):
# create inputs
inputs = self.create_model_inputs(running, prefill)
- sampling_inputs = SamplingInputs.from_sampling_params(running)
- all_ids = __gather_all_ids(running, sampling_inputs)
- guided_input_ids = __gather_guided_input_ids(running, sampling_inputs)
- num_appendable_ids = __get_num_appendable_ids(running)
- num_ignore_eos = __get_num_ignore_eos(running)
+ sampling_inputs = self.sampling_strategy.make_sampling_inputs(running)
return_logits = __need_logits(running)
+ extra_inputs = self.model_agent_strategy.make_extra_inputs(running)
+ stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running)
sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num
return dict(
@@ -957,14 +918,12 @@ def __need_logits(seqs: SeqList):
swap_in_map=swap_in_map,
swap_out_map=swap_out_map,
loop_count=num_loops,
- all_ids=all_ids,
- guided_input_ids=guided_input_ids,
sampling_inputs=sampling_inputs,
- num_appendable_ids=num_appendable_ids,
- num_ignore_eos=num_ignore_eos,
+ stopping_criteria=stopping_criteria,
return_logits=return_logits,
is_dummy=False,
sync_long_context=sync_long_context,
+ extra_inputs=extra_inputs,
)
async def _await_forward_event(self, forward_event: asyncio.Event):
@@ -1132,6 +1091,7 @@ async def _async_loop_main(
forward_event.set()
num_loops = forward_inputs['loop_count']
+ is_decoding = forward_inputs['inputs'].is_decoding
running = next_running
next_running = None
scheduler.lock_running(running)
@@ -1145,7 +1105,7 @@ async def _async_loop_main(
# send output
out = await self.executor.get_output_async()
if out is not None:
- step_outputs = self._make_infer_outputs(out, running=running)
+ step_outputs = self._make_infer_outputs(out, running=running, is_decoding=is_decoding)
resp_que.put_nowait(step_outputs)
# lock forward event
diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py
index 70a906df0f..9e50843a80 100644
--- a/lmdeploy/pytorch/engine/executor/base.py
+++ b/lmdeploy/pytorch/engine/executor/base.py
@@ -35,7 +35,7 @@ def __init__(self,
self.cache_config = cache_config
self.backend_config = backend_config
self.dist_config = dist_config
- self.misc_config = misc_config,
+ self.misc_config = misc_config
self.dp = dist_config.dp
self.tp = dist_config.tp
self.world_size = dist_config.world_size
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index 214f83256e..b30fbb3992 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -109,6 +109,9 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided
return scores
+SeqList = List[SchedulerSequence]
+
+
@dataclass
class SamplingInputs:
temperature: torch.Tensor = None
@@ -120,142 +123,17 @@ class SamplingInputs:
top_k: torch.LongTensor = None
top_p: torch.Tensor = None
min_p: torch.Tensor = None
- random_seeds: int = None
- random_offsets: int = None
+ random_seeds: torch.Tensor = None
+ random_offsets: torch.Tensor = None
max_top_k: int = 1
min_top_p: float = 1.0
response_formats: Tuple[str] = ()
logits_processors: List[List[LogitsProcessor]] = None
max_num_logprobs: Optional[int] = None
-
- @classmethod
- def from_sampling_params(cls, seqs: List[SchedulerSequence]):
- """From samplingg params."""
- batch_size = len(seqs)
- temperature = [None] * batch_size
- repetition_penalty = [None] * batch_size
- top_k = [None] * batch_size
- top_p = [None] * batch_size
- min_p = [None] * batch_size
- bad_words = [None] * batch_size
- stop_words = [None] * batch_size
- random_seeds = [torch.seed() & 0xffffffff] * batch_size
- random_offsets = [None] * batch_size
- response_formats = [None] * batch_size
- logits_processors = [None] * batch_size
- num_logprobs = [None] * batch_size
-
- def __gather_params():
- """Gather params."""
- for idx, seq in enumerate(seqs):
- param = seq.sampling_param
- temperature[idx] = param.temperature
- repetition_penalty[idx] = param.repetition_penalty
- top_k[idx] = param.top_k
- top_p[idx] = param.top_p
- min_p[idx] = param.min_p
- random_offsets[idx] = seq.random_offsets
- response_formats[idx] = param.response_format
- if param.random_seed is not None:
- random_seeds[idx] = param.random_seed & 0xffffffff
-
- bw = param.bad_words
- sw = param.stop_words
- if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens):
- bw = bw + sw
- bad_words[idx] = bw
- stop_words[idx] = sw
- logits_processors[idx] = param.logits_processors
- num_logprobs[idx] = param.num_logprobs
-
- def __get_topp(top_p):
- """Get topp."""
- min_top_p = min(top_p)
- if min_top_p == 1.0:
- top_p = None
- else:
- top_p = torch.tensor(top_p)
- return top_p, min_top_p
-
- def __get_minp(min_p):
- """Get minp."""
- max_min_p = max(min_p)
- if max_min_p == 0.0:
- min_p = None
- else:
- min_p = torch.Tensor(min_p)
- return min_p
-
- def __get_bad_words(bad_words):
- """Get bad words."""
- max_bw_len = max(len(bw) for bw in bad_words)
- if max_bw_len == 0:
- return None, None
- if all(len(bw) == max_bw_len for bw in bad_words):
- ret = torch.tensor(bad_words)
- mask = torch.ones_like(ret, dtype=bool)
- return ret, mask
- ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)
- for idx, bw in enumerate(bad_words):
- bw_len = len(bw)
- if bw_len == 0:
- continue
- bw = ret.new_tensor(bw)
- ret[idx, :bw_len] = bw
-
- mask = ret >= 0
- ret = ret.where(mask, 0)
- return ret, mask
-
- __gather_params()
-
- if all(rp == 1.0 for rp in repetition_penalty):
- repetition_penalty = None
- else:
- repetition_penalty = torch.tensor(repetition_penalty)
-
- temperature = torch.tensor(temperature)
-
- bad_words, bad_mask = __get_bad_words(bad_words)
- stop_words, stop_mask = __get_bad_words(stop_words)
-
- max_top_k = max(top_k)
- if min(top_k) <= 0:
- max_top_k = 0
- if max_top_k == 1:
- top_k = None
- top_p, min_top_p = None, 1.0
- min_p = None
- random_seeds = None
- random_offsets = None
- else:
- top_k = torch.tensor(top_k)
- top_p, min_top_p = __get_topp(top_p)
- min_p = __get_minp(min_p)
- random_seeds = torch.tensor(random_seeds)
- random_offsets = torch.tensor(random_offsets)
-
- max_num_logprobs = max(num_logprobs)
-
- sampling_input = cls(
- temperature=temperature,
- bad_words=bad_words,
- bad_mask=bad_mask,
- stop_words=stop_words,
- stop_mask=stop_mask,
- repetition_penalty=repetition_penalty,
- top_k=top_k,
- top_p=top_p,
- min_p=min_p,
- random_seeds=random_seeds,
- random_offsets=random_offsets,
- response_formats=tuple(response_formats),
- max_top_k=max_top_k,
- min_top_p=min_top_p,
- logits_processors=logits_processors,
- max_num_logprobs=max_num_logprobs,
- )
- return sampling_input
+ all_ids: Optional[torch.Tensor] = None
+ guided_input_ids: Optional[torch.Tensor] = None
+ num_ignore_eos: torch.Tensor = None
+ batch_size: int = 0
def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
@@ -284,12 +162,10 @@ class FusedLogitsProcessor:
def __init__(self,
sampling_inputs: SamplingInputs,
- ignore_eos: torch.Tensor,
tokenizer: Optional[Tokenizer] = None,
sampling_vocab_size: Optional[int] = None,
logprobs_mode: Optional[str] = None):
self.sampling_inputs: SamplingInputs = sampling_inputs
- self.ignore_eos = ignore_eos
self.tokenizer = tokenizer
self.sampling_vocab_size = sampling_vocab_size
self.logprobs_mode = logprobs_mode
@@ -300,12 +176,9 @@ async def _wait_stream_once(self):
if not stream.query():
await asyncio.sleep(0)
- async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.LongTensor,
- scores: torch.FloatTensor) -> torch.FloatTensor:
+ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
r"""
Args:
- all_ids (torch.LongTensor): All the token ids.
- guided_input_ids (torch.LongTensor): Guided prompt ids.
scores (torch.FloatTensor):
Prediction scores of a language modeling head.
These can be logits for each vocabulary when not using
@@ -331,6 +204,8 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
logprobs = None
sampling_inputs = self.sampling_inputs
+ all_ids = sampling_inputs.all_ids
+ guided_input_ids = sampling_inputs.guided_input_ids
custom_logits_processors = self.sampling_inputs.logits_processors
if any(custom_logits_processors):
@@ -352,8 +227,9 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long
stop_words = sampling_inputs.stop_words
if stop_words is not None:
+ ignore_eos = sampling_inputs.num_ignore_eos > 0
stop_mask = sampling_inputs.stop_mask
- stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False)
+ stop_mask = torch.where(ignore_eos[:, None], stop_mask, False)
scores = _process_bad_words_(scores, stop_words, stop_mask)
if guided_input_ids is not None:
diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py
index 3457e5efe9..7b332b6f0c 100644
--- a/lmdeploy/pytorch/engine/model_agent.py
+++ b/lmdeploy/pytorch/engine/model_agent.py
@@ -25,6 +25,8 @@
from ..distributed import DistContext, get_dist_manager
from ..model_inputs import ModelInputs, step_ctx_manager
from ..models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map
+from ..strategies import build_strategy_factory
+from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria
from ..utils import get_gpu_memory
from ..weight_loader.model_weight_loader import load_model_weights
from .cache_engine import CacheEngine
@@ -63,10 +65,12 @@ def to_tensor(self):
class BatchedOutputs:
next_token_ids: torch.Tensor
stopped: torch.Tensor
+ stop_pos: Optional[torch.Tensor] = None
logits: Optional[torch.Tensor] = None
model_metas: List[Dict[str, Any]] = None
logprobs: Optional[BatchedLogProbs] = None
new_token_timestamp: int = 0
+ extra_outputs: Optional[ExtraOutputs] = None
def to_cpu(self):
"""To cpu."""
@@ -199,6 +203,8 @@ def msg_with_rank(rank: int, msg: str):
def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
"""Perform cache swapping."""
issued_cache_op = False
+ swap_in_map = swap_in_map or dict()
+ swap_out_map = swap_out_map or dict()
if len(swap_in_map) > 0:
cache_engine.swap_in(swap_in_map)
issued_cache_op = True
@@ -242,24 +248,11 @@ def model_forward(
# InternVL-3.5-Flash will change the seqlen, model_metas during forward
model_metas = context.model_metas
- seq_length = context.q_seqlens
+ seq_length = context.q_seqlens[:len(inputs.seq_length)]
return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length)
-@record_function('stopping_criteria')
-def _batch_stopping_criteria(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor):
- """Batched stopping criteria."""
- num_appendable_ids = num_appendable_ids - 1
- stopped = num_appendable_ids <= 0
- if stop_words is not None:
- sw_stopped = (token_ids[:, None] == stop_words).any(1)
- stopped = stopped | sw_stopped
- one_ids = torch.clamp_max(num_appendable_ids, 0)
- num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)
- return stopped, num_appendable_ids
-
-
def _try_to_cuda(val, non_blocking: bool = False):
if val is None:
return val
@@ -369,6 +362,11 @@ def __init__(self,
self.enable_microbatch_decode_batchsize_threshold = \
int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2))
+ # strategy
+ self.strategy_factory = build_strategy_factory(model_config, misc_config)
+ self.inputs_strategy = self.strategy_factory.build_model_inputs_strategy()
+ self.agent_strategy = self.strategy_factory.build_model_agent_strategy()
+
@contextmanager
def all_context(self):
device_mgr = get_device_manager()
@@ -399,33 +397,42 @@ def warmup(self):
num_tokens = max_batches
# warmup prefill
- inputs = ModelInputs.make_dummy(max_batches,
- is_decoding=False,
- device='cuda',
- vocab_size=self.model_config.vocab_size)
- self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())
+ inputs = self.inputs_strategy.make_dummy(max_batches,
+ is_decoding=False,
+ device='cuda',
+ vocab_size=self.model_config.vocab_size)
+ self._forward_impl(inputs)
# warmup decoding(with cuda graph)
capture_batch_sizes = self.patched_model.get_capture_batch_sizes()
capture_batch_sizes = sorted(capture_batch_sizes, reverse=True)
for num_tokens in capture_batch_sizes:
- inputs = ModelInputs.make_dummy(num_tokens,
- is_decoding=True,
- device='cuda',
- vocab_size=self.model_config.vocab_size)
- self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict())
+ inputs = self.inputs_strategy.make_dummy(num_tokens,
+ is_decoding=True,
+ device='cuda',
+ vocab_size=self.model_config.vocab_size)
+ self._forward_impl(inputs)
+
+ def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor):
+ """Slice outputs."""
+ return self.agent_strategy.slice_outputs(inputs, seq_length)
+
+ def _postprocess_forward_output(self, output: dict, inputs: ModelInputs):
+ """Post process forward output."""
+ hidden_states = output['hidden_states']
+ seq_length = output.get('seq_length', inputs.seq_length)
+ hidden_states = self._slice_outs(hidden_states[0], seq_length)[None]
+ output['hidden_states'] = hidden_states
+ return output
async def _async_model_forward(
self,
inputs: ModelInputs,
- swap_in_map: Dict,
- swap_out_map: Dict,
return_logits: bool,
sync_long_context: bool,
):
"""Model forward."""
max_prefill_token_num = self.cache_config.max_prefill_token_num
- swap_done = False
class _OutputGather:
"""Output gather."""
@@ -433,7 +440,8 @@ class _OutputGather:
def __init__(self, max_seq_len):
self._max_seq_len = max_seq_len
self._start = 0
- self._output = None
+ self._output: torch.Tensor = None
+ self._device: torch.device = None
def gather(self, output):
"""gather."""
@@ -448,6 +456,7 @@ def gather(self, output):
seq_len = tmp_output.size(-2)
if out_logits is None:
out_logits = tmp_output.new_empty(1, self._max_seq_len, tmp_output.size(-1), device='cpu')
+ self._device = tmp_output.device
out_logits[:, start:start + seq_len].copy_(tmp_output, non_blocking=True)
self._start = start + seq_len
self._output = out_logits
@@ -457,16 +466,9 @@ def get_output(self):
if not return_logits:
return self._output[:, -1:]
torch.cuda.synchronize()
- return self._output
+ return self._output.to(self._device)
- async def __forward(inputs):
- """forward."""
- nonlocal swap_done, swap_in_map, swap_out_map
- if swap_done:
- return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict())
- else:
- swap_done = True
- return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+ __forward = self.async_forward
async def __long_context_single_forward(new_inputs, max_seqlen: int):
"""One large sequence."""
@@ -485,6 +487,8 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
tmp_out['hidden_states'] = output_gather.get_output()
return tmp_out
+ origin_inputs = inputs
+
# make long context inputs
is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding
max_seqlen = 0
@@ -507,24 +511,15 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
if not is_long_context:
ret = await __forward(inputs)
- if not return_logits and not inputs.is_decoding:
- # fetch seq_length from the context, since models may change it (e.g. InternVL-3.5-Flash)
- seq_length = ret.get('seq_length', None)
- assert seq_length is not None, 'seq_length cannot be None.'
- last_token_loc = seq_length.cumsum(0) - 1
-
- ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
else:
ret = await __long_context_single_forward(inputs, max_seqlen)
- if not return_logits:
- last_token_loc = [-1]
- ret['hidden_states'] = ret['hidden_states'][:, last_token_loc]
- else:
- ret['hidden_states'] = ret['hidden_states'].to('cuda')
+
+ if not return_logits:
+ ret = self._postprocess_forward_output(ret, origin_inputs)
# compute dummy loop
if dummy_loop > 0:
- dummy_inputs = ModelInputs.make_dummy(1, False, 'cuda', vocab_size=self.model_config.vocab_size)
+ dummy_inputs = self.inputs_strategy.make_dummy(1, False, 'cuda', vocab_size=self.model_config.vocab_size)
for _ in range(dummy_loop):
await __forward(dummy_inputs)
@@ -533,29 +528,18 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int):
ret['logits'] = logits
return ret
- async def async_sampling_logits(self, logits: torch.Tensor, all_ids: torch.Tensor, guided_input_ids: torch.Tensor,
- sampling_inputs: SamplingInputs, inputs: ModelInputs, ignore_eos: torch.Tensor):
+ async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs, inputs: ModelInputs):
"""Sampling logits."""
- def __get_last_logits():
- """Get last logits."""
- seq_length = inputs.seq_length
- if len(seq_length) == logits.size(0):
- return logits
-
- last_idx = seq_length.cumsum(-1) - 1
- return logits[last_idx, :]
-
# record function does not support async function
# so we can not decorate it on async_sampling_logits
with record_function('sampling_logits'):
- split_logits = __get_last_logits()
logits_processor = FusedLogitsProcessor(sampling_inputs,
- ignore_eos,
self.tokenizer,
sampling_vocab_size=self.sampling_vocab_size,
logprobs_mode=self.misc_config.logprobs_mode)
- logits, raw_logprobs = await logits_processor(all_ids, guided_input_ids, split_logits)
+ origin_logits = logits
+ logits, raw_logprobs = await logits_processor(origin_logits)
next_token_ids = logits_processor.sampling(logits)
logprobs = logits_processor.compute_logprobs(raw_logprobs, next_token_ids)
if logprobs is not None:
@@ -572,17 +556,18 @@ def _push_output(self, output: BatchedOutputs):
event.record()
self._out_que.put_nowait((output, event))
- def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None):
+ @contextmanager
+ def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None, enable: bool = True):
+ if not enable:
+ yield
+ return
+
if dist_ctx is None:
dist_ctx = get_dist_manager().current_context()
- if self.cache_config.role == EngineRole.Decode:
- next_token_ids = next_token_ids.cpu()
- tp_cpu_group = dist_ctx.tp_cpu_group
- dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group)
- else:
- tp_gpu_group = dist_ctx.tp_gpu_group
- dist.broadcast(next_token_ids, src=0, group=tp_gpu_group)
- return next_token_ids
+ tp_gpu_group = dist_ctx.tp_gpu_group
+ handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True)
+ yield
+ handle.wait()
async def _async_step_background(
self,
@@ -590,38 +575,26 @@ async def _async_step_background(
loop_count: int,
swap_in_map: Dict = None,
swap_out_map: Dict = None,
- all_ids: torch.Tensor = None,
- guided_input_ids: torch.Tensor = None,
sampling_inputs: SamplingInputs = None,
- num_appendable_ids: torch.LongTensor = None,
- num_ignore_eos: torch.LongTensor = None,
+ stopping_criteria: StoppingCriteria = None,
return_logits: bool = False,
is_dummy: bool = False,
sync_long_context: bool = False,
+ extra_inputs: ExtraInputs = None,
):
"""Asyc forward task."""
- if swap_in_map is None:
- swap_in_map = dict()
-
- if swap_out_map is None:
- swap_out_map = dict()
-
dist_ctx = get_dist_manager().current_context()
@record_function('update_inputs_for_next_step')
- def __update_inputs(next_token_ids, model_metas):
+ def __update_inputs(next_token_ids, model_metas, extra_inputs):
"""Update inputs."""
- nonlocal all_ids, guided_input_ids, swap_in_map, swap_out_map
- swap_in_map = dict()
- swap_out_map = dict()
- inputs.model_metas = model_metas
- inputs.update(next_token_ids)
- if all_ids is not None:
- all_ids = torch.cat([all_ids, next_token_ids[:, None].to(all_ids.device)], 1)
- if guided_input_ids is not None:
- guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None].to(guided_input_ids.device)], 1)
- if sampling_inputs.random_offsets is not None:
- sampling_inputs.random_offsets += 1
+ return self.agent_strategy.update_inputs_for_next_step(
+ inputs,
+ sampling_inputs,
+ next_token_ids=next_token_ids,
+ model_metas=model_metas,
+ extra_inputs=extra_inputs,
+ )
@asynccontextmanager
async def __prepare_dp():
@@ -705,61 +678,78 @@ async def __prepare_dp():
logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.')
return
+ cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
for idx in range(loop_count):
# inference
logger.debug(f' rank[{rank}]: model forward [{idx}].')
output = await self._async_model_forward(
inputs,
- swap_in_map=swap_in_map,
- swap_out_map=swap_out_map,
return_logits=return_logits,
sync_long_context=sync_long_context,
)
logits = output['logits']
logits = logits[0] # [bs, seq, prob] -> [seq, prob]
+ seq_length = output.get('seq_length', inputs.seq_length)
+ last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob]
+ extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length)
# output empty for dummy inputs
if is_dummy:
continue
+ need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1)
+
# sampling and stopping
if need_output:
logger.debug(f' rank[{rank}]: Sampling [{idx}].')
# sampling
- next_token_ids, logprobs = await self.async_sampling_logits(logits, all_ids, guided_input_ids,
- sampling_inputs, inputs, num_ignore_eos > 0)
- num_ignore_eos = num_ignore_eos - 1
+ next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs)
+
+ with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
+ logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]')
+
+ # post sampling
+ next_token_ids, extra_inputs = self.agent_strategy.post_sampling(
+ inputs, last_logits, next_token_ids, extra_inputs)
- # stopping criteria
- stopped, num_appendable_ids = _batch_stopping_criteria(next_token_ids, sampling_inputs.stop_words,
- num_appendable_ids)
+ # stopping criteria
+ stopped, stop_pos, stopping_criteria = stopping_criteria.step(next_token_ids,
+ sampling_inputs.stop_words,
+ inputs=inputs,
+ extra_inputs=extra_inputs)
else:
# Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,
# as it can trigger recompilation on different ranks when using torch.compile.
with torch.inference_mode():
- next_token_ids = torch.zeros_like(num_ignore_eos)
+ next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0))
logprobs = None
- # broadcast next token for TP > 1
- need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1)
- if need_broadcast_next:
- logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]')
- next_token_ids = self._broadcast_next_token(next_token_ids, dist_ctx)
+ # broadcast next token for TP > 1
+ with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next):
+ logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]')
+
+ # post sampling
+ next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids,
+ extra_inputs)
# send output
model_metas = output.get('model_metas')
if need_output:
logger.debug(f' rank[{rank}]: Output [{idx}]')
+ extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
self._push_output(
BatchedOutputs(next_token_ids=next_token_ids,
logits=logits if return_logits else None,
stopped=stopped,
+ stop_pos=stop_pos,
model_metas=model_metas,
- logprobs=logprobs))
+ logprobs=logprobs,
+ extra_outputs=extra_outputs))
# update for next loop
if is_decoding and idx < loop_count - 1:
- __update_inputs(next_token_ids, model_metas)
+ inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs)
async def _async_loop_background(self, forward_event: asyncio.Event = None):
"""Async loop background."""
@@ -787,7 +777,7 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None):
async def _async_loop_inputs_preprocess(self):
"""Async loop inputs preprocess."""
non_blocking = True
- keys = ['inputs', 'all_ids', 'guided_input_ids', 'sampling_inputs', 'num_appendable_ids', 'num_ignore_eos']
+ keys = ['inputs', 'sampling_inputs', 'stopping_criteria', 'extra_inputs']
while True:
forward_inputs = await self._pre_in_que.get()
@@ -923,7 +913,9 @@ def _build_model(self):
if custom_module_map is not None:
update_custom_module_map(custom_module_map)
logger.debug(msg_with_rank(rank, 'build model.'))
- build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder)
+ build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder,
+ dllm_config=self.misc_config.dllm_config,
+ strategy_factory=self.strategy_factory)
patched_model = build_patched_model(self.model_config,
device=device,
model_format=self.misc_config.model_format,
@@ -935,6 +927,7 @@ def _build_model(self):
logger.debug(msg_with_rank(rank, 'loading adapters.'))
add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device)
self.patched_model = patched_model
+ self.build_model_ctx = build_model_ctx
def build_model(self):
"""Build model api."""
@@ -965,8 +958,7 @@ def build_cache_engine(self):
world_size=tp,
cache_stream=self.cache_stream)
- def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
- cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+ def _forward_impl(self, inputs: ModelInputs):
output = model_forward(
self.patched_model,
inputs,
@@ -975,7 +967,7 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map:
)
return output
- async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap):
+ async def async_forward(self, inputs: ModelInputs):
"""Model forward.
Args:
@@ -983,7 +975,7 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_ou
-            swap_in_map (SwapMap): Cache maps to swap in.
-            swap_out_map (SwapMap): Cache maps to swap out.
"""
- output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
+ output = self._forward_impl(inputs)
await asyncio.sleep(0)
return output
@@ -1091,6 +1083,7 @@ def __init__(self, model_agent: BaseModelAgent):
self.model_config = model_agent.model_config
self.cache_config = model_agent.cache_config
self.misc_config = model_agent.misc_config
+ self.inputs_strategy = model_agent.inputs_strategy
self.device = model_agent.device
self._in_que = model_agent._in_que
@@ -1106,10 +1099,10 @@ def _make_dummy_forward_inputs(self):
dist_config = self.dist_ctx.dist_config
batch_size = 2 if dist_config.enable_microbatch else 1
batch_size = min(self.cache_config.max_batches, batch_size)
- model_inputs = ModelInputs.make_dummy(batch_size,
- is_decoding,
- device=self.device,
- vocab_size=self.model_config.vocab_size)
+ model_inputs = self.inputs_strategy.make_dummy(batch_size,
+ is_decoding,
+ device=self.device,
+ vocab_size=self.model_config.vocab_size)
forward_inputs = dict(
inputs=model_inputs,
loop_count=loop_count,
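Editor's note: the decode loop above now wraps token-id synchronization in `with self._broadcast_next_token(..., enable=need_broadcast_next)` instead of calling a broadcast helper directly. A minimal sketch of that pattern, with illustrative names (not the actual LMDeploy implementation), assuming an initialized `torch.distributed` process group:

```python
from contextlib import contextmanager

import torch
import torch.distributed as dist


@contextmanager
def broadcast_next_token(next_token_ids: torch.Tensor, group=None, enable: bool = True):
    """Conditionally broadcast `next_token_ids` in place from rank 0.

    When `enable` is False the context is a no-op; otherwise the broadcast is
    launched asynchronously, the `with` body runs while it is in flight, and
    the handle is waited on when the block exits.
    """
    if not enable:
        yield
        return
    handle = dist.broadcast(next_token_ids, src=0, group=group, async_op=True)
    try:
        yield
    finally:
        handle.wait()
```

Used as `with broadcast_next_token(ids, enable=(dp == 1 and tp > 1)): ...`, which mirrors the `need_broadcast_next` condition in the loop above.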
diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py
index a22f3d36c0..8b53b84aa2 100644
--- a/lmdeploy/pytorch/kernels/cuda/flashattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py
@@ -56,7 +56,8 @@ def _load_kv(ptrs, boundary_check: tl.constexpr):
def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask,
kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr,
logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr,
- shared_kv: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr):
+ shared_kv: tl.constexpr, block_sparse_size: tl.constexpr, BLOCK_N: tl.constexpr,
+ BLOCK_DK1: tl.constexpr):
k_ptrs = tl.advance(k_ptrs, (0, loop_start))
v_ptrs = tl.advance(v_ptrs, (loop_start, 0))
if BLOCK_DK1:
@@ -77,7 +78,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start
qk *= sm_scale
qk = softcapping(qk, logit_softcapping)
qk = qk * tl_log2(math.e)
- qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
+ if block_sparse_size > 1:
+ offs_mask = (start_n + offs_n) // block_sparse_size * block_sparse_size
+ qk_mask = (history_mask[:, None]) >= offs_mask[None, :]
+ else:
+ qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
if window_size > 0:
qk_mask = qk_mask & ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
qk = tl.where(
@@ -180,6 +185,7 @@ def _flash_prefill_fwd_kernel(
window_size: tl.constexpr,
logit_softcapping: tl.constexpr,
shared_kv: tl.constexpr,
+ block_sparse_size: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_DK: tl.constexpr,
@@ -295,6 +301,7 @@ def _flash_prefill_fwd_kernel(
k_bound=k_bound0,
v_bound=v_bound0,
shared_kv=shared_kv,
+ block_sparse_size=block_sparse_size,
BLOCK_N=BLOCK_N,
BLOCK_DK1=BLOCK_DK1)
@@ -322,6 +329,7 @@ def _flash_prefill_fwd_kernel(
k_bound=k_bound1,
v_bound=v_bound1,
shared_kv=shared_kv,
+ block_sparse_size=block_sparse_size,
BLOCK_N=BLOCK_N,
BLOCK_DK1=BLOCK_DK1)
# epilogue
@@ -448,6 +456,7 @@ def flash_attention_fwd(
logit_softcapping: float = None,
sinks: Tensor = None,
causal: bool = True,
+ block_sparse_size: int = 1,
kv_layout: str = 'hsd',
):
"""Varlen flash Attention forward.
@@ -546,6 +555,7 @@ def grid(args):
window_size=window_size,
logit_softcapping=logit_softcapping,
shared_kv=shared_kv,
+ block_sparse_size=block_sparse_size,
BLOCK_DK=BLOCK_DK,
BLOCK_DK1=BLOCK_DK1,
BLOCK_DV=BLOCK_DV,
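Editor's note: `block_sparse_size` relaxes the causal mask so a query can attend to every key inside the dLLM block that key belongs to — key offsets are floored to their block start before the comparison against the per-query history mask. A small PyTorch sketch of the same arithmetic outside Triton (illustrative only):

```python
import torch

def block_causal_mask(history_mask: torch.Tensor, kv_offsets: torch.Tensor,
                      block_sparse_size: int = 1) -> torch.Tensor:
    """qk_mask[i, j] is True when query i may attend to key j."""
    if block_sparse_size > 1:
        # round each key offset down to the start of its dLLM block
        kv_offsets = kv_offsets // block_sparse_size * block_sparse_size
    return history_mask[:, None] >= kv_offsets[None, :]

# A query with history_mask == 1 sees keys 0..3 when block_sparse_size == 4,
# versus only keys 0..1 under the dense causal mask.
print(block_causal_mask(torch.tensor([1]), torch.arange(8), block_sparse_size=4))
print(block_causal_mask(torch.tensor([1]), torch.arange(8), block_sparse_size=1))
```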
diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
index a6097e7027..a4941c36af 100644
--- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -59,6 +59,7 @@ def _fwd_grouped_split_kernel(
stride_od: tl.constexpr,
stride_boffb,
kv_group_num: tl.constexpr,
+ seq_len: tl.constexpr,
window_size: tl.constexpr,
head_size: tl.constexpr,
head_size_v: tl.constexpr,
@@ -74,18 +75,20 @@ def _fwd_grouped_split_kernel(
):
"""First step kernel of split k attention."""
cur_batch = tl.program_id(2)
- cur_kv_head = tl.program_id(0)
+ tile_id = tl.program_id(0)
split_k_id = tl.program_id(1)
- if BLOCK_H < kv_group_num:
- HEAD_PER_CTA: tl.constexpr = BLOCK_H
- else:
- HEAD_PER_CTA: tl.constexpr = kv_group_num
- cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
- mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
+ HEADS_PER_REQ: tl.constexpr = kv_group_num * seq_len
+ TILES_PER_GROUP: tl.constexpr = tl.cdiv(HEADS_PER_REQ, BLOCK_H)
+ subtile_id = tile_id % TILES_PER_GROUP
+ cur_kv_head = tile_id // TILES_PER_GROUP
+ offs_h = subtile_id * BLOCK_H + tl.arange(0, BLOCK_H)
+ cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num
+ cur_token = cur_batch * seq_len + offs_h // kv_group_num
+
+ mask_h = cur_head < cur_kv_head * kv_group_num + kv_group_num
+ mask_h = mask_h & (cur_token < cur_batch * seq_len + seq_len)
mask_h = mask_h & (cur_head < num_heads_q)
- if BLOCK_H < kv_group_num:
- cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num
q_seqlen = 1
kv_seqlen = tl.load(KV_seqlens + cur_batch)
@@ -104,7 +107,7 @@ def _fwd_grouped_split_kernel(
off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs)
- off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)
+ off_q = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(Q + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0)
k_ptrs = K + off_k
@@ -114,7 +117,7 @@ def _fwd_grouped_split_kernel(
offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1)
mask_d1 = offs_d1 < head_size
offs_d1 = offs_d1 % head_size
- off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)
+ off_q1 = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd)
q1 = tl.load(Q + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0)
off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs)
k1_ptrs = K + off_k1
@@ -196,11 +199,11 @@ def _fwd_grouped_split_kernel(
# initialize pointers to output
if loop_end > loop_start:
- off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +
+ off_acc = (cur_token[:, None] * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh +
offs_dv[None, :] * stride_od)
tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :])
- off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)
+ off_meta = (cur_token * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v)
tl.store(Acc_out + off_meta, m_i, mask=mask_h)
tl.store(Acc_out + off_meta + 1, l_i, mask=mask_h)
@@ -588,7 +591,9 @@ def _get_block_d(Lk):
if sm_scale is None:
sm_scale = 1.0 / (Lq**0.5)
batch, head = kv_seqlens.shape[0], q.shape[-2]
- kv_group_num = q.shape[-2] // k.shape[h_dim]
+ num_tokens = q.shape[-3]
+ num_kv_heads = k.shape[h_dim]
+ kv_group_num = head // num_kv_heads
if sinks is not None:
assert sinks.is_contiguous()
@@ -601,20 +606,22 @@ def _get_block_d(Lk):
'might leads to bad performance. '
'Please reduce `block_size`.')
- is_decoding = q.shape[-3] == kv_seqlens.size(0)
- assert is_decoding, 'we only support decoding paged attention.'
+ valid = num_tokens % batch == 0
+ assert valid, 'we only support decoding paged attention.'
+ seq_len = num_tokens // batch
BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq)
- p2_kv_group_num = triton.next_power_of_2(kv_group_num)
- BLOCK_H = max(16, min(BLOCK, p2_kv_group_num))
- grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num))
+ HEADS_PER_REQ = kv_group_num * seq_len
+ BLOCK_H = max(16, min(BLOCK, triton.next_power_of_2(HEADS_PER_REQ)))
+ TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H)
+ grid_1 = TILES_PER_GROUP * num_kv_heads
SPLIT_K = _get_split_k(q.device.index, grid_1, batch)
if quant_policy != 4:
- acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32)
+ acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32)
else:
- acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
+ acc = q.new_empty(num_tokens, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32)
grid = (
grid_1,
@@ -704,6 +711,7 @@ def _get_block_d(Lk):
stride_od=acc.stride(-1),
stride_boffb=block_offsets.stride(0),
kv_group_num=kv_group_num,
+ seq_len=seq_len,
window_size=window_size,
head_size=Lk,
head_size_v=Lv,
@@ -720,7 +728,7 @@ def _get_block_d(Lk):
num_stages=num_stages)
num_warps = 4
- grid = (batch, head)
+ grid = (num_tokens, head)
if quant_policy == 4:
Lv *= 2
BLOCK_DV *= 2
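Editor's note: the split-k kernel now maps `program_id(0)` to a (kv head, sub-tile) pair and packs `kv_group_num * seq_len` query heads per request into `BLOCK_H`-wide tiles, so one decode step can carry `seq_len > 1` tokens per sequence (one dLLM block). A host-side sketch of the same indexing, showing which (kv head, query head, token) triples a tile covers (helper name is illustrative):

```python
import math

def tile_coverage(tile_id: int, kv_group_num: int, seq_len: int, block_h: int):
    """Enumerate the (kv_head, query_head, token-in-block) triples of one tile."""
    heads_per_req = kv_group_num * seq_len
    tiles_per_group = math.ceil(heads_per_req / block_h)
    subtile_id = tile_id % tiles_per_group
    cur_kv_head = tile_id // tiles_per_group
    out = []
    for lane in range(block_h):
        offs_h = subtile_id * block_h + lane
        if offs_h >= heads_per_req:       # masked out in the kernel
            continue
        cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num
        token = offs_h // kv_group_num    # token index within the dLLM block
        out.append((cur_kv_head, cur_head, token))
    return out

# 2 query heads per kv head, 4 tokens per step, 16-lane tiles -> one tile per kv head
print(tile_coverage(0, kv_group_num=2, seq_len=4, block_h=16))
```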
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index 4cc559a20b..101ed62546 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -1,8 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import enum
-import time
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
from torch import Tensor
@@ -14,6 +13,9 @@
from .block import LogicalTokenBlocks
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy
+
logger = get_logger('lmdeploy')
# vlm input type from pipeline
@@ -56,7 +58,7 @@ class SamplingParam:
num_logprobs: int = -1
@classmethod
- def from_gen_config(self, gen_config: GenerationConfig):
+ def from_gen_config(cls, gen_config: GenerationConfig):
"""From gen config."""
min_new_tokens = gen_config.min_new_tokens or 0
@@ -158,29 +160,33 @@ class MessageStatus(enum.Enum):
MIGRATION_DONE = enum.auto()
-_SEQ_COUNT = 0
-
-
-def _new_msg_id():
- """Get a new message id."""
- global _SEQ_COUNT
- seq_id = _SEQ_COUNT
- _SEQ_COUNT += 1
- return seq_id
+SeqMap = Dict[int, 'SchedulerSequence']
-SeqMap = Dict[int, 'SchedulerSequence']
+@dataclass
+class SequenceMeta:
+ """Meta data shared by all sequence."""
+ block_size: int
+ strategy: 'SequenceStrategy' = None
class SequenceManager:
"""Sequence manager."""
- def __init__(self) -> None:
+ def __init__(self, seq_meta: SequenceMeta) -> None:
self._seq_map: SeqMap = dict()
self._status_seq_map: Dict[MessageStatus, SeqMap] = dict()
for status in MessageStatus:
self._status_seq_map[status] = dict()
+ self.seq_meta = seq_meta
+ self._seq_count = 0
+
+ def _new_seq_id(self):
+ seq_id = self._seq_count
+ self._seq_count += 1
+ return seq_id
+
def get_all_sequences(self):
"""Get all sequences."""
return self._seq_map.values()
@@ -223,12 +229,23 @@ def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageSt
new_status_map[seq_id] = seq
+def _to_ndarray(token_ids) -> np.ndarray:
+ """To ndarray."""
+ if isinstance(token_ids, Tensor):
+ token_ids = token_ids.numpy()
+ elif not isinstance(token_ids, np.ndarray):
+ token_ids = np.array(token_ids)
+ if token_ids.ndim == 0:
+ token_ids = token_ids[None]
+ return token_ids
+
+
class SchedulerSession:
"""Scheduler session."""
- def __init__(self, session_id: int, block_size: int, seq_manager: SequenceManager = None) -> None:
+ def __init__(self, session_id: int, seq_manager: SequenceManager) -> None:
self.session_id = session_id
- self.block_size = block_size
+ self.seq_meta = seq_manager.seq_meta
self.status: MessageStatus = MessageStatus.RUNNING
self.sequences: SeqMap = dict()
self.seq_manager = seq_manager
@@ -237,48 +254,38 @@ def add_sequence(self,
token_ids: Tensor,
sampling_param: SamplingParam = None,
adapter_name: str = None,
- return_logits: bool = False,
multimodals: MultiModalInputs = None,
input_embeddings: List[InputEmbeddings] = None,
migration_request: Optional[MigrationRequest] = None,
resp_cache: bool = False,
preserve_cache: bool = False) -> 'SchedulerSequence':
"""Add a new message."""
- if isinstance(token_ids, Tensor):
- token_ids = token_ids.numpy()
- elif not isinstance(token_ids, np.ndarray):
- token_ids = np.array(token_ids)
- if token_ids.ndim == 0:
- token_ids = token_ids.unsqueeze(0)
if sampling_param is None:
sampling_param = SamplingParam()
- seq = SchedulerSequence(
- seq_id=_new_msg_id(),
- session=self,
- history_cache=HistoryTokenIds(token_ids),
- num_new_tokens=0,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- arrive_time=time.perf_counter(),
- history_embeddings=HistoryEmbeddings(input_embeddings),
- history_multimodals=HistoryMultiModals(multimodals),
- return_logits=return_logits,
- migration_request=migration_request,
- resp_cache=resp_cache,
- preserve_cache=preserve_cache,
+ seq_id = self.seq_manager._new_seq_id()
+ seq = self.seq_meta.strategy.make_sequence(seq_id=seq_id,
+ session=self,
+ sampling_param=sampling_param,
+ adapter_name=adapter_name,
+ migration_request=migration_request,
+ resp_cache=resp_cache,
+ preserve_cache=preserve_cache)
+ seq.update_token_ids(
+ token_ids,
+ multimodals=multimodals,
+ embeddings=input_embeddings,
+ mode=UpdateTokenMode.INPUTS,
)
self.sequences[seq.seq_id] = seq
- if self.seq_manager is not None:
- self.seq_manager.add_sequence(seq)
+ self.seq_manager.add_sequence(seq)
return seq
def remove_sequence(self, seq: 'SchedulerSequence'):
"""Remove sequence."""
assert seq.seq_id in self.sequences
self.sequences.pop(seq.seq_id)
- if self.seq_manager is not None:
- self.seq_manager.remove_sequence(seq)
+ self.seq_manager.remove_sequence(seq)
def _div_up(x, n):
@@ -342,9 +349,9 @@ class HistoryTokenIds:
"""History token ids."""
ALLOC_SIZE = 512
- def __init__(self, token_ids: np.ndarray = None):
+ def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = np.int64):
if token_ids is None:
- self._token_ids = np.empty((self.ALLOC_SIZE, ), dtype=np.int64)
+ self._token_ids = np.empty((self.ALLOC_SIZE, ), dtype=dtype)
self._num_real = 0
else:
self._token_ids = token_ids
@@ -363,6 +370,11 @@ def get_real(self):
"""Get logical blocks."""
return self._token_ids[:self._num_real]
+ def resize(self, size: int):
+ """Set size."""
+ assert size <= self._num_real
+ self._num_real = size
+
def __setitem__(self, *args, **kwargs):
"""Set values."""
return self.get_real().__setitem__(*args, **kwargs)
@@ -397,7 +409,7 @@ def copy(self):
class HistoryMultiModals:
- def __init__(self, multimodals: MultiModalInputs):
+ def __init__(self, multimodals: MultiModalInputs = None):
if multimodals is None:
multimodals = dict()
self.multimodals = multimodals
@@ -453,6 +465,13 @@ def get_encoder_len(self, start=0, end=-1):
return out_len
+class UpdateTokenMode(enum.Enum):
+ """Update token mode."""
+ INPUTS = enum.auto()
+ PREFILL = enum.auto()
+ DECODE = enum.auto()
+
+
@dataclass
class SchedulerSequence:
"""Scheduler message."""
@@ -466,9 +485,8 @@ class SchedulerSequence:
logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks)
adapter_name: str = None
arrive_time: float = 0.0
+ output_start_pos: int = 0
meta: Any = None
- return_logits: bool = False
- random_offsets: int = 0
_status: MessageStatus = field(default=MessageStatus.WAITING, init=False)
num_ignored_history: int = 0
model_meta: Dict[str, Any] = None
@@ -483,23 +501,20 @@ class SchedulerSequence:
def __post_init__(self):
"""Post init."""
- self._num_history_ids: int = 0
+ self._seq_meta: SequenceMeta = self.session.seq_meta
self._num_history_images: int = 0
- self._num_images: int = len(self.history_embeddings)
+ self._num_history_ids: int = 0
self._num_token_ids: int = len(self.history_cache)
+ # vlm
+ self._num_images: int = len(self.history_embeddings)
self._num_history_cross: int = 0
self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids)
@property
def block_size(self) -> int:
"""Block size."""
- return self.session.block_size
-
- @property
- def history_len(self) -> int:
- """Get history length."""
- return self._num_history_ids
+ return self._seq_meta.block_size
@property
def history_image_num(self) -> int:
@@ -519,7 +534,7 @@ def session_id(self) -> int:
@property
def token_ids(self) -> np.ndarray:
"""Token ids."""
- start = self.history_len
+ start = self.num_history_ids
end = start + self._num_token_ids
return self.history_cache._token_ids[start:end]
@@ -533,13 +548,24 @@ def input_embeddings(self) -> List[InputEmbeddings]:
@property
def history_ids(self) -> np.ndarray:
"""History ids."""
- return self.history_cache._token_ids[:self.history_len]
+ return self.history_cache._token_ids[:self.num_history_ids]
@property
def all_ids(self) -> np.ndarray:
"""Full token ids."""
return self.history_cache._token_ids[:self.num_all_ids]
+ @property
+ def valid_ids(self) -> np.ndarray:
+ """Valid token ids."""
+ return self.history_cache._token_ids[:self.num_valid_ids]
+
+ @property
+ def generated_ids(self) -> np.ndarray:
+ end = self.num_valid_ids
+ start = end - self.num_new_tokens
+ return self.history_cache._token_ids[start:end]
+
@property
def num_history_ids(self):
"""Num history ids."""
@@ -549,6 +575,10 @@ def num_history_ids(self):
def num_token_ids(self):
return self._num_token_ids
+ @property
+ def num_valid_ids(self):
+ return self._num_history_ids + self._num_token_ids
+
@property
def num_images(self):
return self._num_images
@@ -556,7 +586,7 @@ def num_images(self):
@property
def num_all_ids(self):
"""Num all tokens."""
- return self.history_len + self._num_token_ids
+ return self._num_history_ids + self._num_token_ids
@property
def num_cross(self):
@@ -582,15 +612,15 @@ def seq_manager(self) -> SequenceManager:
def status(self):
return self._status
+ @property
+ def return_logits(self):
+ return self.sampling_param.out_logits
+
@status.setter
def status(self, value: MessageStatus):
self.seq_manager.update_sequence_status(self, value)
self._status = value
- def num_all_tokens(self):
- """Num all tokens."""
- return self.num_all_ids
-
def num_all_cross_tokens(self):
"""Num of all cross tokens."""
return self._num_cross + self._num_history_cross
@@ -601,79 +631,45 @@ def get_input_multimodals(self):
end = self.num_all_ids
return self.history_multimodals.get_datas(start, end)
- def update_token_ids(self,
- token_ids: Tensor,
- multimodals: MultiModalInputs = None,
- embeddings: List[InputEmbeddings] = None,
- model_meta: Dict[str, Any] = None,
- append_tokens: bool = False):
- """Update token ids, old token ids will be added to history."""
- old_num_history_ids = self._num_history_ids
-
- # update history
- if not append_tokens:
- self._num_history_ids += self._num_token_ids
+ def record_event(
+ self,
+ event_type: EventType,
+ timestamp: Optional[float] = None,
+ ) -> None:
+ self.engine_events.append(EngineEvent.new_event(event_type, timestamp))
- # update history image nums
+ def _update_embeddings(self, embeddings: List[InputEmbeddings]):
+ """Update input embeddings."""
self._num_history_images += self._num_images
- self._num_images = 0
- if embeddings is not None:
- new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings]
- self._num_images = len(new_embeddings)
- self.history_embeddings.append(new_embeddings)
-
- # update multimodals
- if multimodals is not None:
- multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids)
- self.history_multimodals.add_inputs(multimodals)
+ if embeddings is None:
+ self._num_images = 0
+ return
+ new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings]
+ self._num_images = len(new_embeddings)
+ self.history_embeddings.append(new_embeddings)
- # cross
+ def _update_multimodals(self, multimodals: MultiModalInputs):
+ """Update input multimodals."""
self._num_history_cross += self._num_cross
- if multimodals is not None:
- self._num_cross = self.history_multimodals.get_encoder_len(old_num_history_ids, self._num_history_ids)
- else:
+ if multimodals is None:
self._num_cross = 0
+ return
+ multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids)
+ self.history_multimodals.add_inputs(multimodals)
- if model_meta is not None:
- self.model_meta = model_meta
-
- if isinstance(token_ids, Tensor):
- token_ids = token_ids.numpy()
- elif not isinstance(token_ids, np.ndarray):
- token_ids = np.array(token_ids)
- if token_ids.ndim == 0:
- token_ids = token_ids[None]
- if append_tokens:
- self._num_token_ids += len(token_ids)
- else:
- self._num_token_ids = len(token_ids)
- self.history_cache.append(token_ids)
- self.random_offsets += 1
- self.arrive_time = time.perf_counter()
+ # for mllama
+ self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, self._num_history_ids)
+
+ def update_token_ids(self,
+ token_ids: Tensor,
+ multimodals: MultiModalInputs = None,
+ embeddings: List[InputEmbeddings] = None,
+ model_meta: Dict[str, Any] = None,
+ mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
+ **kwargs):
+ """Update token ids, old token ids will be added to history."""
+ raise NotImplementedError('NotImplemented')
def set_step(self, step: int):
"""Set step."""
- num_all_ids = self.num_all_ids
- # update step for vlm
- if len(self.history_embeddings) > 0:
- new_step, self._num_history_images, self._num_images = \
- self.history_embeddings.get_step(step)
- assert 0 <= new_step <= step
- step = new_step
- self._num_history_ids = step
- self._num_token_ids = num_all_ids - step
- self.num_ignored_history = min(step, self.num_ignored_history)
-
- self.model_meta = None
-
- # cross
- if self.history_multimodals is not None:
- self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids)
- self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids)
-
- def record_event(
- self,
- event_type: EventType,
- timestamp: Optional[float] = None,
- ) -> None:
- self.engine_events.append(EngineEvent.new_event(event_type, timestamp))
+ raise NotImplementedError('NotImplemented')
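Editor's note: `_to_ndarray` centralizes the token-id normalization that `add_sequence` used to do inline, and fixes the old branch that called `.unsqueeze(0)` on a NumPy scalar (ndarrays have no such method). The intended behavior, as a quick check assuming this patch is applied:

```python
import numpy as np
import torch

# _to_ndarray is the private helper added to lmdeploy/pytorch/messages.py above.
from lmdeploy.pytorch.messages import _to_ndarray

assert _to_ndarray(torch.tensor([1, 2, 3])).tolist() == [1, 2, 3]  # Tensor -> ndarray
assert _to_ndarray([4, 5]).tolist() == [4, 5]                      # list -> ndarray
assert _to_ndarray(np.int64(7)).shape == (1,)                      # 0-d scalar -> 1-D array
```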
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 2afa8a0b9f..a377c9d4d6 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass, field, fields
-from typing import Any, Dict, List, Literal
+from typing import TYPE_CHECKING, Any, Dict, List, Literal
import torch
from torch.profiler import record_function
@@ -9,9 +9,12 @@
# from torch import distributed as dist
import lmdeploy.pytorch.distributed as dist
from lmdeploy.pytorch.backends import get_backend
-from lmdeploy.pytorch.config import ModelConfig
+from lmdeploy.pytorch.config import DLLMConfig, ModelConfig
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.strategies.base import StrategyFactoryBase
+
@dataclass
class DPMeta:
@@ -142,12 +145,14 @@ class ModelInputs:
dp_meta: 'DPMeta' = None
enable_microbatch: bool = False
- def update(self, input_ids: torch.LongTensor):
+ def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None):
"""Update input ids."""
assert self.is_decoding
- self.history_lengths = self.history_lengths + 1
- self.max_kv_seqlen += 1
- self.sum_kv_seqlen += self.seq_length.numel()
+ if step_seqlens is None:
+ step_seqlens = self.seq_length
+ self.history_lengths += step_seqlens
+ self.max_kv_seqlen += self.max_q_seqlen
+ self.sum_kv_seqlen += self.max_kv_seqlen * self.seq_length.numel()
if input_ids.dim() == 1:
input_ids = input_ids[None, :]
self.input_ids = input_ids
@@ -279,38 +284,6 @@ def build_dp_meta(self):
"""Build dp meta."""
self.dp_meta = DPMeta.build(self.input_ids.numel())
- @classmethod
- @record_function('make_dummy_input')
- def make_dummy(cls,
- batch_size: int,
- is_decoding: bool,
- device: str = 'cpu',
- dummy_block_id: int = 0,
- vocab_size: int = 1):
- """Make dummy inputs."""
- input_ids = torch.randint(0, vocab_size, (
- 1,
- batch_size,
- ), dtype=torch.long, device=device)
- seq_length = torch.ones((batch_size, ), dtype=torch.long, device=device)
- history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device)
- block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device)
- num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device)
- local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device)
-
- return cls(
- input_ids=input_ids,
- seq_length=seq_length,
- history_lengths=history_lengths,
- block_offsets=block_offsets,
- is_decoding=is_decoding,
- num_ignored_history=num_ignored_history,
- max_q_seqlen=1,
- max_kv_seqlen=1,
- sum_kv_seqlen=batch_size,
- local_adapter_ids=local_adapter_ids,
- )
-
def log_info(self):
"""Get log info."""
ret = (f'num_tokens={self.input_ids.numel()}, batch_size={self.seq_length.numel()}'
@@ -426,18 +399,27 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs):
"""Get position ids."""
q_seqlens = inputs.seq_length
history_seqlens = inputs.history_lengths
+ max_q_seqlen = inputs.max_q_seqlen
# decoding
- if inputs.is_decoding:
+ if max_q_seqlen == 1:
attention_mask = torch.ones_like(q_seqlens)[:, None]
position_ids = history_seqlens.unsqueeze(-1).clone()
position_ids = position_ids.flatten()
return attention_mask, position_ids
num_tokens = inputs.input_ids.numel()
- max_q_seqlen = inputs.max_q_seqlen
+ batch_size = inputs.seq_length.numel()
device = q_seqlens.device
+ # batch with same seqlens
+ if max_q_seqlen * batch_size == num_tokens:
+ attention_mask = None
+ ranges = torch.arange(0, max_q_seqlen, device=device)
+ position_ids = history_seqlens[:, None] + ranges[None, :]
+ position_ids = position_ids.flatten()
+ return attention_mask, position_ids
+
# get mask
mask_range = torch.arange(max_q_seqlen, device=device)[None, :]
attention_mask = (mask_range < q_seqlens[:, None]).long()
@@ -451,14 +433,24 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs):
return attention_mask, position_ids_1d
+@dataclass
+class BuildModelContext:
+ """Context for building model."""
+ disable_vision_encoder: bool = False
+ dllm_config: DLLMConfig = None
+ strategy_factory: 'StrategyFactoryBase' = None
+
+
class StepContextManager:
- def __init__(self):
+ def __init__(self, build_ctx: BuildModelContext = None):
self._current_ctx = None
+ build_ctx = build_ctx or BuildModelContext()
+ self.build_ctx = build_ctx
- @staticmethod
@record_function('build_step_context')
def build_context(
+ self,
inputs: ModelInputs,
model_config: ModelConfig,
kv_caches: List = None,
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index 0578a0ac4c..498e2c6554 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -225,4 +225,10 @@
'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM',
})
+# SDAR
+MODULE_MAP.update({
+ 'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM',
+ 'SDARMoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar_moe.SDARMoeForCausalLM',
+})
+
CUSTOM_MODULE_MAP = dict()
diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py
index 97ad271d9a..86aa5ea51f 100644
--- a/lmdeploy/pytorch/models/patch.py
+++ b/lmdeploy/pytorch/models/patch.py
@@ -9,9 +9,9 @@
from typing import Any, Dict
import torch
-from attr import dataclass
from transformers.configuration_utils import PretrainedConfig
+from lmdeploy.pytorch.model_inputs import BuildModelContext, StepContextManager
from lmdeploy.utils import get_logger
from ..config import ModelConfig
@@ -188,10 +188,11 @@ def _get_model_class(config, module_map):
def build_model_from_hf_config(model_config: PretrainedConfig,
dtype: torch.dtype = None,
device: torch.device = None,
+ ctx_mgr: StepContextManager = None,
build_model_ctx: 'BuildModelContext' = None):
"""Build model from hf config."""
- from lmdeploy.pytorch.model_inputs import StepContextManager
- ctx_mgr = StepContextManager()
+ if ctx_mgr is None:
+ ctx_mgr = StepContextManager(build_model_ctx)
module_map = _get_module_map()
if device is None:
device = torch.device('cuda')
@@ -333,12 +334,6 @@ def add_adapters(model: torch.nn.Module,
return target_infos
-@dataclass
-class BuildModelContext:
- """Context for building model."""
- disable_vision_encoder: bool = False
-
-
BUILD_MODEL_CTX = BuildModelContext()
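Editor's note: with `BuildModelContext` moved into `model_inputs.py` and threaded through `StepContextManager`, model classes such as `SDARForCausalLM` can read `ctx_mgr.build_ctx.dllm_config.block_length` at construction time. A hedged sketch of that wiring; the model path is a placeholder, and `DLLMConfig(block_length=4)` assumes `DLLMConfig` is a dataclass whose only field used here is `block_length`:

```python
import torch
from transformers import AutoConfig

from lmdeploy.pytorch.config import DLLMConfig
from lmdeploy.pytorch.model_inputs import BuildModelContext, StepContextManager
from lmdeploy.pytorch.models.patch import build_model_from_hf_config

hf_config = AutoConfig.from_pretrained('/path/to/sdar-model', trust_remote_code=True)  # placeholder path
build_ctx = BuildModelContext(dllm_config=DLLMConfig(block_length=4))  # block_length is what sdar.py reads
ctx_mgr = StepContextManager(build_ctx)
model = build_model_from_hf_config(hf_config,
                                   dtype=torch.bfloat16,
                                   device=torch.device('cuda'),
                                   ctx_mgr=ctx_mgr)
```

In the engine itself this happens inside `_build_model`, which passes `self.misc_config.dllm_config` and the strategy factory into `BuildModelContext` as shown earlier in this diff.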
diff --git a/lmdeploy/pytorch/models/sdar.py b/lmdeploy/pytorch/models/sdar.py
new file mode 100644
index 0000000000..6a624e40e4
--- /dev/null
+++ b/lmdeploy/pytorch/models/sdar.py
@@ -0,0 +1,405 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Any, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
+from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+
+from .utils.cudagraph import CudaGraphMixin
+
+
+class SDARAttention(nn.Module):
+ """attention."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ num_heads = config.num_attention_heads
+ num_key_value_heads = config.num_key_value_heads
+ hidden_size = config.hidden_size
+ head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
+ num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
+ # packed qkv
+ # Qwen3 uses 'config.attention_bias = False' for q/k/o projections
+ self.qkv_proj = build_qkv_proj(hidden_size,
+ num_q_heads=num_heads,
+ num_kv_heads=num_key_value_heads,
+ head_size=head_dim,
+ bias=config.attention_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ num_replicate_kv_heads=num_replicate_kv_heads)
+
+ # rotary embedding
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+ dllm_block_length = config.dllm_block_length
+
+ # attention
+ self.attn_fwd = Attention(
+ num_heads,
+ head_dim,
+ num_kv_heads=num_key_value_heads,
+ v_head_size=head_dim,
+ sliding_window=config.sliding_window,
+ block_sparse_size=dllm_block_length,
+ )
+
+ # o_proj
+ self.o_proj = build_o_proj(num_heads * head_dim,
+ hidden_size,
+ bias=config.attention_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ # q, k norm
+ self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
+ self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attn_metadata: Any = None,
+ ):
+ """Rewrite of LlamaAttention.forward."""
+ # qkv proj
+ qkv_states = self.qkv_proj(hidden_states)
+ # (-1, heads, head_dim)
+ qkv_states = qkv_states.flatten(0, -2)
+ query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)
+
+ # apply q, k norm
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ # apply rotary embedding
+ cos, sin = rotary_pos_emb
+ query_states, key_states = self.apply_rotary_pos_emb(
+ query_states,
+ key_states,
+ cos,
+ sin,
+ )
+ # attention
+ attn_output = self.attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ past_key_value[0],
+ past_key_value[1],
+ attn_metadata,
+ k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
+ v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
+ inplace=True,
+ )
+ attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
+
+ # o proj
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+
+class SDARMLP(nn.Module):
+ """mlp."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ # gate up
+ self.gate_up_proj = build_gateup_linear(
+ config.hidden_size,
+ [config.intermediate_size, config.intermediate_size],
+ bias=False,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # silu and mul
+ self.act_fn = SiluAndMul(inplace=True)
+
+ # down
+ self.down_proj = build_down_linear(config.intermediate_size,
+ config.hidden_size,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x):
+ """forward."""
+ gate_up = self.gate_up_proj(x)
+ act = self.act_fn(gate_up)
+ return self.down_proj(act)
+
+
+class SDARDecoderLayer(nn.Module):
+ """Decode layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ quantization_config = getattr(config, 'quantization_config', None)
+
+ # build attention layer
+ self.self_attn = SDARAttention(config, dtype=dtype, device=device)
+
+ # build MLP
+ self.mlp = SDARMLP(config, dtype=dtype, device=device)
+
+ # build input layer norm
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ # build attention layer norm
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Optional[List[torch.FloatTensor]],
+ residual: Optional[torch.Tensor] = None,
+ attn_metadata: Any = None,
+ ):
+
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ attn_metadata=attn_metadata,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+
+ outputs = (hidden_states, residual)
+ return outputs
+
+
+class SDARModel(nn.Module):
+ """model."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ dtype=dtype,
+ device=device)
+
+ # build all decode layers
+ self.layers = nn.ModuleList([
+ SDARDecoderLayer(config, layer_idx, dtype=dtype, device=device)
+ for layer_idx in range(config.num_hidden_layers)
+ ])
+
+ # build norm
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
+
+ # build rotary embedding
+ self.rotary_emb = build_rotary_embedding_from_config(config)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ attn_metadata: Any = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Rewrite of forward."""
+
+ # token embedding
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ # rotary embedding
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
+ cos, sin = cos[0], sin[0]
+ rotary_pos_emb = (cos, sin)
+
+ # decoding
+ residual = None
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx]
+ hidden_states, residual = decoder_layer(
+ hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ residual=residual,
+ attn_metadata=attn_metadata,
+ )
+
+ # norm
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.embed_tokens
+
+
+class SDARForCausalLM(nn.Module, CudaGraphMixin):
+ """ModelForCausalLM."""
+
+ packed_modules_mapping = {
+ 'qkv_proj': [
+ 'q_proj',
+ 'k_proj',
+ 'v_proj',
+ ],
+ 'gate_up_proj': [
+ 'gate_proj',
+ 'up_proj',
+ ],
+ }
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.ctx_mgr = ctx_mgr
+ config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length
+ # build model
+ self.model = SDARModel(config, dtype=dtype, device=device)
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ inputs_embeds: torch.Tensor = None,
+ **kwargs,
+ ):
+ """Model forward, return logits."""
+ hidden_states = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def get_logits(self, hidden_states: torch.Tensor):
+ """Compute logits of the model output."""
+ return self.lm_head(hidden_states)
+
+ def update_weights(self):
+ """Update weights."""
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None,
+ ):
+ """Prepare input."""
+ # get input_ids, position_ids and attention metadatas
+ input_ids = context.input_ids
+ position_ids = context.position_ids
+ attn_metadata = context.attn_metadata
+
+ # process vision embeddings
+ vision_embeddings = context.input_embeddings
+ vision_embedding_indexing = context.input_embedding_indexing
+ if vision_embeddings is not None and len(vision_embeddings) > 0:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)
+
+ # inputs of forward
+ return dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ )
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """Load weights."""
+ # modify from vllm
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ('.gate_up_proj', '.gate_proj', 0),
+ ('.gate_up_proj', '.up_proj', 1),
+ ]
+
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if 'rotary_emb.inv_freq' in name:
+ continue
+ if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
+ continue
+ if self.config.tie_word_embeddings and 'lm_head.weight' in name:
+ continue
+
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
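Editor's note: `load_weights` relies on Python's `for ... else` to dispatch each checkpoint tensor — if its name matches a fused projection it is renamed and loaded as a shard, otherwise the `else` branch loads it under its own name. A minimal standalone illustration of that dispatch with toy names (no real weights):

```python
stacked_params_mapping = [
    ('.qkv_proj', '.q_proj', 'q'),
    ('.qkv_proj', '.k_proj', 'k'),
    ('.qkv_proj', '.v_proj', 'v'),
    ('.gate_up_proj', '.gate_proj', 0),
    ('.gate_up_proj', '.up_proj', 1),
]

def route(name: str):
    """Return (target_param_name, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        target = (name.replace(weight_name, param_name), shard_id)
        break
    else:  # loop ran to completion: no fused target matched
        target = (name, None)
    return target

assert route('model.layers.0.self_attn.k_proj.weight') == ('model.layers.0.self_attn.qkv_proj.weight', 'k')
assert route('model.layers.0.mlp.up_proj.weight') == ('model.layers.0.mlp.gate_up_proj.weight', 1)
assert route('model.norm.weight') == ('model.norm.weight', None)
```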
diff --git a/lmdeploy/pytorch/models/sdar_moe.py b/lmdeploy/pytorch/models/sdar_moe.py
new file mode 100644
index 0000000000..522d2aed95
--- /dev/null
+++ b/lmdeploy/pytorch/models/sdar_moe.py
@@ -0,0 +1,501 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from lmdeploy.pytorch.distributed import get_tp_world_rank
+from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
+from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
+from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
+ build_rowwise_linear)
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
+from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
+
+from .utils.cudagraph import CudaGraphMixin
+
+
+class SDARMoeAttention(nn.Module):
+ """attention."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ num_heads = config.num_attention_heads
+ num_key_value_heads = config.num_key_value_heads
+ hidden_size = config.hidden_size
+ head_dim = getattr(config, 'head_dim', hidden_size // num_heads)
+ num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1)
+ # packed qkv
+ # Qwen3 uses 'config.attention_bias = False' for q/k/o projections
+ self.qkv_proj = build_qkv_proj(hidden_size,
+ num_q_heads=num_heads,
+ num_kv_heads=num_key_value_heads,
+ head_size=head_dim,
+ bias=config.attention_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ num_replicate_kv_heads=num_replicate_kv_heads)
+
+ # rotary embedding
+ self.apply_rotary_pos_emb = ApplyRotaryEmb()
+ dllm_block_length = config.dllm_block_length
+
+ # attention
+ self.attn_fwd = Attention(
+ num_heads,
+ head_dim,
+ num_kv_heads=num_key_value_heads,
+ v_head_size=head_dim,
+ sliding_window=config.sliding_window,
+ block_sparse_size=dllm_block_length,
+ )
+
+ # o_proj
+ self.o_proj = build_o_proj(num_heads * head_dim,
+ hidden_size,
+ bias=config.attention_bias,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ # q, k norm
+ self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
+ self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attn_metadata: Any = None,
+ ):
+ """Rewrite of LlamaAttention.forward."""
+ # qkv proj
+ qkv_states = self.qkv_proj(hidden_states)
+ # (-1, heads, head_dim)
+ qkv_states = qkv_states.flatten(0, -2)
+ query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states)
+
+ # apply q, k norm
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ # apply rotary embedding
+ cos, sin = rotary_pos_emb
+ query_states, key_states = self.apply_rotary_pos_emb(
+ query_states,
+ key_states,
+ cos,
+ sin,
+ )
+ # attention
+ attn_output = self.attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ past_key_value[0],
+ past_key_value[1],
+ attn_metadata,
+ k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2],
+ v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3],
+ inplace=True,
+ )
+ attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1)
+
+ # o proj
+ attn_output = self.o_proj(attn_output)
+ return attn_output
+
+
+class SDARMoeMLP(nn.Module):
+ """mlp."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ intermediate_size: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
+ # gate up
+ self.gate_up_proj = build_gateup_linear(
+ config.hidden_size,
+ [intermediate_size, intermediate_size],
+ bias=False,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ is_tp=True,
+ )
+
+ # silu and mul
+ self.act_fn = SiluAndMul(inplace=True)
+
+ # down
+ self.down_proj = build_down_linear(intermediate_size,
+ config.hidden_size,
+ bias=False,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device,
+ is_tp=True)
+
+ def forward(self, x):
+ """forward."""
+ gate_up = self.gate_up_proj(x)
+ act = self.act_fn(gate_up)
+ return self.down_proj(act)
+
+
+class SDARMoeSparseMoeBlock(nn.Module):
+ """SDARMoeSparseMoeBlock."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ quantization_config = getattr(config, 'quantization_config', None)
+ self.hidden_dim = config.hidden_size
+ self.ffn_dim = config.moe_intermediate_size
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ self.renormalize = config.norm_topk_prob
+
+ self.gate = build_rowwise_linear(
+ self.hidden_dim,
+ self.num_experts,
+ bias=False,
+ dtype=dtype,
+ device=device,
+ is_tp=False,
+ )
+
+ self.softmax_topk = SoftmaxTopK(self.top_k)
+
+ world_size, _ = get_tp_world_rank()
+ _all_reduce = world_size > 1
+ self.experts = build_fused_moe(
+ self.hidden_dim,
+ self.ffn_dim,
+ self.num_experts,
+ top_k=self.top_k,
+ renormalize=self.renormalize,
+ dtype=dtype,
+ device=device,
+ quant_config=quantization_config,
+ all_reduce=_all_reduce,
+ layer_idx=layer_idx,
+ )
+
+ def forward(self, hidden_states: torch.Tensor):
+ """forward."""
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+ router_logits = self.gate(hidden_states)
+ topk_weights, topk_ids = self.softmax_topk(router_logits)
+ out_states = self.experts(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ )
+
+ out_states = out_states.reshape(batch_size, sequence_length, -1)
+ return out_states
+
+
+class SDARMoeDecoderLayer(nn.Module):
+ """Decode layer."""
+
+ def __init__(self,
+ config: PretrainedConfig,
+ layer_idx: int,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ quantization_config = getattr(config, 'quantization_config', None)
+
+ # build attention layer
+ self.self_attn = SDARMoeAttention(config, dtype=dtype, device=device)
+
+ # build MLP
+ if (layer_idx not in config.mlp_only_layers) and (config.num_experts > 0 and
+ (layer_idx + 1) % config.decoder_sparse_step == 0):
+ self.mlp = SDARMoeSparseMoeBlock(config, layer_idx, dtype=dtype, device=device)
+ else:
+ self.mlp = SDARMoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device)
+
+ # build input layer norm
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ # build attention layer norm
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ quant_config=quantization_config,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor],
+ past_key_value: Optional[List[torch.FloatTensor]],
+ residual: Optional[torch.Tensor] = None,
+ attn_metadata: Any = None,
+ ):
+
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ attn_metadata=attn_metadata,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+
+ outputs = (hidden_states, residual)
+ return outputs
+
+
+class SDARMoeModel(nn.Module):
+ """SDAR model."""
+
+ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
+ super().__init__()
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size,
+ config.hidden_size,
+ self.padding_idx,
+ dtype=dtype,
+ device=device)
+
+ # build all decode layers
+ self.layers = nn.ModuleList([
+ SDARMoeDecoderLayer(config, layer_idx, dtype=dtype, device=device)
+ for layer_idx in range(config.num_hidden_layers)
+ ])
+
+ # build norm
+ self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
+
+ # build rotary embedding
+ self.rotary_emb = build_rotary_embedding_from_config(config)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ attn_metadata: Any = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Rewrite of forward."""
+
+ # token embedding
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ # rotary embedding
+ cos, sin = self.rotary_emb(hidden_states, position_ids)
+ cos, sin = cos[0], sin[0]
+ rotary_pos_emb = (cos, sin)
+
+ # decoding
+ residual = None
+ for idx, decoder_layer in enumerate(self.layers):
+ past_key_value = past_key_values[idx]
+ hidden_states, residual = decoder_layer(
+ hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ past_key_value=past_key_value,
+ residual=residual,
+ attn_metadata=attn_metadata,
+ )
+
+ # norm
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.embed_tokens
+
+
+class SDARMoeForCausalLM(nn.Module, CudaGraphMixin):
+ """ModelForCausalLM."""
+
+ packed_modules_mapping = {
+ 'qkv_proj': [
+ 'q_proj',
+ 'k_proj',
+ 'v_proj',
+ ],
+ 'gate_up_proj': [
+ 'gate_proj',
+ 'up_proj',
+ ],
+ }
+
+ def __init__(self,
+ config: PretrainedConfig,
+ ctx_mgr: StepContextManager,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.ctx_mgr = ctx_mgr
+ config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length
+ # build model
+ self.model = SDARMoeModel(config, dtype=dtype, device=device)
+ # build lm_head
+ self.lm_head = build_rowwise_linear(config.hidden_size,
+ config.vocab_size,
+ bias=False,
+ dtype=dtype,
+ device=device)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ past_key_values: List[List[torch.Tensor]],
+ attn_metadata: Any = None,
+ inputs_embeds: torch.Tensor = None,
+ **kwargs,
+ ):
+ """Model forward, return logits."""
+ hidden_states = self.model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def get_logits(self, hidden_states: torch.Tensor):
+ """Compute logits of the model output."""
+ return self.lm_head(hidden_states)
+
+ def update_weights(self):
+ """Update weights."""
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+
+ def get_input_embeddings(self):
+ """Get input embeddings."""
+ return self.model.get_input_embeddings()
+
+ def prepare_inputs_for_generation(
+ self,
+ past_key_values: List[List[torch.Tensor]],
+ inputs_embeds: Optional[torch.Tensor] = None,
+ context: StepContext = None,
+ ):
+ """Prepare input."""
+ # get input_ids, position_ids and attention metadatas
+ input_ids = context.input_ids
+ position_ids = context.position_ids
+ attn_metadata = context.attn_metadata
+
+ # process vision embeddings
+ vision_embeddings = context.input_embeddings
+ vision_embedding_indexing = context.input_embedding_indexing
+ if vision_embeddings is not None and len(vision_embeddings) > 0:
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+ inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds)
+
+ # inputs of forward
+ return dict(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ attn_metadata=attn_metadata,
+ inputs_embeds=inputs_embeds,
+ )
+
+ def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter],
+ expert_params_mapping: List):
+ """Load weight experts."""
+ # load fused weights
+ for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id)
+ break
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ """Load weights."""
+ # modify from vllm
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ('.qkv_proj', '.q_proj', 'q'),
+ ('.qkv_proj', '.k_proj', 'k'),
+ ('.qkv_proj', '.v_proj', 'v'),
+ ('.gate_up_proj', '.gate_proj', 0),
+ ('.gate_up_proj', '.up_proj', 1),
+ ]
+
+ # expert map
+ num_experts = self.config.num_experts
+ expert_params_mapping = []
+ for exp_id in range(num_experts):
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down')
+ expert_params_mapping += [gate_param, up_param, down_param]
+
+ params_dict = dict(self.named_parameters())
+ for name, loaded_weight in weights:
+ if 'rotary_emb.inv_freq' in name:
+ continue
+ if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name):
+ continue
+ if self.config.tie_word_embeddings and 'lm_head.weight' in name:
+ continue
+
+ if '.experts' in name:
+ self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping)
+ else:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ load_weight(param, loaded_weight, shard_id=shard_id)
+ break
+ else:
+ param = params_dict[name]
+ load_weight(param, loaded_weight)
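Editor's note: the MoE variant additionally folds per-expert checkpoint tensors (`experts.{i}.gate_proj/up_proj/down_proj`) into the fused `experts.gate_up` / `experts.down` parameters, carrying the expert index as `expert_id`. A small sketch of how one checkpoint name resolves under that mapping (illustrative expert count and helper name):

```python
num_experts = 4  # illustrative; the real value comes from config.num_experts

expert_params_mapping = []
for exp_id in range(num_experts):
    expert_params_mapping += [
        ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate'),
        ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up'),
        ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down'),
    ]

def route_expert(name: str):
    """Return (fused_param_name, expert_id, shard_id) for an expert tensor name."""
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), expert_id, shard_id
    return name, None, None

assert route_expert('model.layers.3.mlp.experts.2.up_proj.weight') == \
    ('model.layers.3.mlp.experts.gate_up.weight', 2, 'up')
```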
diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py
index b86adbbb89..9901bf443c 100644
--- a/lmdeploy/pytorch/models/utils/cudagraph.py
+++ b/lmdeploy/pytorch/models/utils/cudagraph.py
@@ -142,15 +142,15 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
new_inputs['cross_attn_metadata'] = cross_attn_metadata
if is_decoding:
- new_inputs['input_ids'] = input_buffers['input_ids'][:, :new_batch_size]
- new_inputs['position_ids'] = input_buffers['position_ids'][:, :new_batch_size]
+ new_inputs['input_ids'] = input_buffers['input_ids'][:, :num_tokens]
+ new_inputs['position_ids'] = input_buffers['position_ids'][:, :num_tokens]
else:
new_inputs['input_ids'] = input_buffers['input_ids']
new_inputs['position_ids'] = input_buffers['position_ids']
if inputs_embeds is not None:
if is_decoding:
- new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'][:, :new_batch_size]
+ new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'][:, :num_tokens]
else:
new_inputs['inputs_embeds'] = input_buffers['inputs_embeds']
diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py
index ef5d831d51..7a1654db4b 100644
--- a/lmdeploy/pytorch/nn/attention.py
+++ b/lmdeploy/pytorch/nn/attention.py
@@ -37,6 +37,7 @@ def __init__(
causal: bool = True,
use_flash_mla: bool = False,
learnable_sink: bool = False,
+ block_sparse_size: int = 1,
**kwargs,
):
super().__init__()
@@ -61,6 +62,7 @@ def __init__(
causal=causal,
use_flash_mla=use_flash_mla,
learnable_sink=learnable_sink,
+ block_sparse_size=block_sparse_size,
**kwargs,
)
diff --git a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py
index cc41187900..3348f2f8a5 100644
--- a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py
+++ b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py
@@ -25,7 +25,7 @@ class DefaultBlockManager(BaseBlockManager):
@classmethod
def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0):
"""Get num required blocks."""
- num_tokens = obj.num_all_tokens() + prealloc_size
+ num_tokens = obj.num_all_ids + prealloc_size
# cross tokens
num_cross = obj.num_all_cross_tokens()
diff --git a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py
index 3f5ca9605e..1d91c020cb 100644
--- a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py
+++ b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py
@@ -10,10 +10,10 @@
def _num_blocks_to_drop(seq: SchedulerSequence, window_size: int):
"""Num blocks to free."""
- if seq.history_len <= window_size:
+ history_len = seq.num_history_ids
+ if seq.num_history_ids <= window_size:
return 0
block_size = seq.block_size
- history_len = seq.history_len
num_blocks = len(seq.logical_blocks)
win_start_block_id = (history_len - window_size) // block_size
win_end_block_id = (history_len - 1) // block_size
diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py
index b5908dde54..d00690af24 100644
--- a/lmdeploy/pytorch/paging/block_trie.py
+++ b/lmdeploy/pytorch/paging/block_trie.py
@@ -81,7 +81,7 @@ def __match_success(node: Node):
curr = node
num_matched += block_size
- while num_matched + block_size < seq.num_all_ids:
+ while num_matched + block_size < seq.num_valid_ids:
curr_tokens = seq.history_cache[num_matched:num_matched + block_size]
key = hash(('random', tuple(curr_tokens)))
@@ -116,9 +116,9 @@ def allocate(self, seq: SchedulerSequence):
logical_blocks.last_shared_node = node
num_matched = node.num_matched
- num_all_ids = seq.num_all_ids
+ num_valid_ids = seq.num_valid_ids
- if num_matched + block_size > num_all_ids:
+ if num_matched + block_size > num_valid_ids:
return
if len(node.children) == 0 and node.parent is not None:
@@ -127,7 +127,7 @@ def allocate(self, seq: SchedulerSequence):
block_id = num_matched // block_size
blocks = []
free_blocks = []
- while num_matched + block_size <= num_all_ids:
+ while num_matched + block_size <= num_valid_ids:
curr_tokens = seq.history_cache[num_matched:num_matched + block_size]
block = logical_blocks[block_id]
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index 0be35ab9be..8144ac52a1 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -9,7 +9,7 @@
from lmdeploy.utils import get_logger, logging_timer
from ..config import CacheConfig, SchedulerConfig
-from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager
+from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta
from .block_manager import build_block_manager
from .block_trie import BlockTrie
@@ -36,7 +36,10 @@ class Scheduler:
cache_config (CacheConfig): The config of cache info.
"""
- def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None:
+ def __init__(self,
+ scheduler_config: SchedulerConfig,
+ cache_config: CacheConfig,
+ seq_meta: SequenceMeta = None) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
@@ -50,7 +53,8 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig)
self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type)
- self.seq_manager = SequenceManager()
+ seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size)
+ self.seq_manager = SequenceManager(seq_meta)
@property
def waiting(self):
@@ -121,7 +125,7 @@ def add_session(self, session_id: int):
session_id (int): New session id.
"""
assert session_id not in self.sessions
- session = SchedulerSession(session_id, self.cache_config.block_size, seq_manager=self.seq_manager)
+ session = SchedulerSession(session_id, seq_manager=self.seq_manager)
self.sessions[session_id] = session
return session
@@ -179,7 +183,7 @@ def _reorder_migrating():
return running_migration
@logging_timer('SchedulePrefilling', logger)
- def _schedule_prefill(self):
+ def _schedule_prefill(self, prealloc_size: int = 0):
"""Schedule for prefilling."""
max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked()
@@ -203,7 +207,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting):
hanging = reversed(self.hanging)
waiting = reversed(waiting)
evictable = list(chain(hanging, waiting))
- return eviction_helper.evict_for_seq(seq, evictable, 0)
+ return eviction_helper.evict_for_seq(seq, evictable, prealloc_size)
def _reorder_waiting():
"""Reorder waiting."""
@@ -226,7 +230,7 @@ def _reorder_waiting():
break
# allocate session memory
- self.block_manager.allocate(seq)
+ self.block_manager.allocate(seq, prealloc_size)
_to_running(seq)
seq.record_event(EventType.SCHEDULED)
@@ -287,7 +291,7 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int):
def schedule(self, is_prefill: bool, prealloc_size: int = 0):
"""Schedule inputs for next steps."""
if is_prefill:
- output = self._schedule_prefill()
+ output = self._schedule_prefill(0)
else:
output = self._schedule_decoding(prealloc_size)
running, swap_in_map, swap_out_map, copy_map = output
diff --git a/lmdeploy/pytorch/strategies/__init__.py b/lmdeploy/pytorch/strategies/__init__.py
new file mode 100644
index 0000000000..c0f5da1262
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.pytorch.config import MiscConfig, ModelConfig
+
+
+def build_strategy_factory(model_config: ModelConfig, misc_config: MiscConfig):
+ """Build strategy factory."""
+ model_paradigm = model_config.model_paradigm
+
+ if model_paradigm == 'ar':
+ from .ar import ARStrategyFactory
+ return ARStrategyFactory(model_config=model_config)
+ elif model_paradigm == 'dllm':
+ from .dllm import DLLMStrategyFactory
+ return DLLMStrategyFactory(model_config=model_config, dllm_config=misc_config.dllm_config)
+ else:
+ raise RuntimeError(f'Unsupported model paradigm: {model_paradigm}')
diff --git a/lmdeploy/pytorch/strategies/ar/__init__.py b/lmdeploy/pytorch/strategies/ar/__init__.py
new file mode 100644
index 0000000000..b593107c2e
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/__init__.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import TYPE_CHECKING
+
+from lmdeploy.pytorch.config import ModelConfig
+from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy
+
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy
+ from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
+ from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
+ from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
+ from lmdeploy.pytorch.strategies.base.engine import EngineStrategy
+ from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
+
+from ..base import StrategyFactoryBase
+
+
+class ARStrategyFactory(StrategyFactoryBase):
+
+ def __init__(self, model_config: ModelConfig):
+ """Init with model config."""
+ self.model_config = model_config
+
+ def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
+ """Build cudagraph strategy."""
+ from .cudagraph import ARCudagraphStrategy
+ return ARCudagraphStrategy()
+
+ def build_sampling_strategy(self) -> 'SamplingStrategy':
+ """Build sampling strategy."""
+ from .sampling import ARSamplingStrategy
+ pad_token_id = self.model_config.bos_token_id
+ pad_token_id = 0 if pad_token_id is None else pad_token_id
+ return ARSamplingStrategy(pad_token_id)
+
+ def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
+ """Build model inputs strategy."""
+ from .model_inputs import ARModelInputsStrategy
+ return ARModelInputsStrategy()
+
+ def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
+ """Build model agent strategy."""
+ from .model_agent import ARModelAgentStrategy
+ return ARModelAgentStrategy()
+
+ def build_engine_strategy(self, cache_config: 'CacheConfig',
+ scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
+ """Build engine strategy."""
+ from .engine import AREngineStrategy
+ return AREngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config)
+
+ def build_sequence_strategy(self) -> SequenceStrategy:
+ from .sequence import ARSequenceStrategy
+ return ARSequenceStrategy()
diff --git a/lmdeploy/pytorch/strategies/ar/cudagraph.py b/lmdeploy/pytorch/strategies/ar/cudagraph.py
new file mode 100644
index 0000000000..e3749bcfc2
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/cudagraph.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..base.cudagraph import CudagraphStrategy
+
+
+class ARCudagraphStrategy(CudagraphStrategy):
+
+ def get_max_tokens(self, batch_size: int) -> int:
+ """Get max tokens."""
+ return batch_size
diff --git a/lmdeploy/pytorch/strategies/ar/engine.py b/lmdeploy/pytorch/strategies/ar/engine.py
new file mode 100644
index 0000000000..60ae69f8c9
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/engine.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
+
+from ..base.engine import EngineStrategy
+
+
+class AREngineStrategy(EngineStrategy):
+ """AR Engine Strategy."""
+
+ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None:
+ self.scheduler_config = scheduler_config
+ self.cache_config = cache_config
+
+ def get_prealloc_size(self, is_decoding: bool):
+ """Get prealloc_size."""
+ return self.scheduler_config.prefill_interval if is_decoding else 0
+
+ def get_num_loops(self, is_decoding: bool) -> int:
+ """Get num_loops."""
+ return self.scheduler_config.prefill_interval if is_decoding else 1
diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py
new file mode 100644
index 0000000000..4096db2cb7
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/model_agent.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass
+from typing import Any, List, Optional
+
+import torch
+from torch.profiler import record_function
+
+from lmdeploy.pytorch.engine.logits_process import SamplingInputs
+from lmdeploy.pytorch.messages import SchedulerSequence
+from lmdeploy.pytorch.model_inputs import ModelInputs
+
+from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria
+
+SeqList = List[SchedulerSequence]
+
+
+class ARExtraInputs(ExtraInputs):
+ """AR extra inputs."""
+
+
+class ARExtraOutputs(ExtraOutputs):
+ """AR extra outputs."""
+
+
+@dataclass
+class ARStoppingCriteria(StoppingCriteria):
+ num_appendable_ids: torch.Tensor
+
+ @record_function('stopping_criteria')
+ def step(self,
+ token_ids: torch.Tensor,
+ stop_words: torch.Tensor,
+ inputs: Optional[ModelInputs] = None,
+ extra_inputs: Optional[ARExtraInputs] = None):
+ """Check whether to stop generation."""
+ num_appendable_ids = self.num_appendable_ids - 1
+ stopped = num_appendable_ids <= 0
+ stop_pos = torch.zeros_like(num_appendable_ids)
+ if stop_words is not None:
+ sw_stopped = (token_ids[:, None] == stop_words).any(1)
+ stopped = stopped | sw_stopped
+ one_ids = torch.clamp_max(num_appendable_ids, 0)
+ num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)
+
+ # NOTE: in-place assignment does not work here for unclear reasons; return a new instance instead.
+ new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids)
+ return stopped, stop_pos, new_stopping
+
+
+class ARModelAgentStrategy(ModelAgentStrategy):
+
+ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
+ """Slice outputs."""
+ # batch size == 1
+ if len(seq_length) == 1:
+ return inputs[-1:]
+
+ if len(seq_length) == inputs.size(0):
+ return inputs
+ last_idx = seq_length.cumsum(-1) - 1
+ return inputs[last_idx]
+
+ def slice_extra_inputs(self, extra_inputs: ARExtraInputs, seq_length: torch.LongTensor) -> ARExtraInputs:
+ """Slice extra inputs."""
+ return extra_inputs
+
+ def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor):
+ """step."""
+ sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1
+
+ all_ids = sampling_inputs.all_ids
+ if all_ids is not None:
+ sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1)
+
+ guided_input_ids = sampling_inputs.guided_input_ids
+ if guided_input_ids is not None:
+ sampling_inputs.guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None]], 1)
+
+ return sampling_inputs
+
+ def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria:
+ """Create stopping criteria."""
+ num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]
+ num_appendable = torch.tensor(num_appendable)
+ return ARStoppingCriteria(num_appendable_ids=num_appendable)
+
+ def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs:
+ """Create extra inputs."""
+ return ARExtraInputs()
+
+ def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs:
+ """Create extra outputs."""
+ return ARExtraOutputs()
+
+ def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs',
+ next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ARExtraInputs,
+ **kwargs):
+ """Step next inputs."""
+ model_inputs.model_metas = model_metas
+ step_seqlens = model_inputs.seq_length
+ model_inputs.step(next_token_ids, step_seqlens)
+ self._step_sampling_inputs(sampling_inputs, next_token_ids)
+ return model_inputs, extra_inputs
+
+ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
+ extra_inputs: ARExtraInputs):
+ """Post sampling."""
+ return next_token_ids, extra_inputs
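
The last-token gather in `ARModelAgentStrategy.slice_outputs` can be reproduced in isolation: for a packed batch, the last hidden state of each sequence sits at `cumsum(seq_length) - 1`. A small standalone sketch with made-up shapes:

```python
# Standalone sketch of the cumsum-based last-token gather (made-up data).
import torch

hidden = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # 6 packed tokens, hidden dim 2
seq_length = torch.tensor([2, 3, 1])                          # three sequences, 2 + 3 + 1 = 6 tokens
last_idx = seq_length.cumsum(-1) - 1                          # tensor([1, 4, 5])
print(hidden[last_idx])                                       # one hidden state per sequence
```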
diff --git a/lmdeploy/pytorch/strategies/ar/model_inputs.py b/lmdeploy/pytorch/strategies/ar/model_inputs.py
new file mode 100644
index 0000000000..5b263b1e37
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/model_inputs.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.pytorch.model_inputs import ModelInputs
+
+from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs
+
+
+class ARModelInputsStrategy(ModelInputsStrategy):
+
+ def make_dummy(self,
+ batch_size: int,
+ is_decoding: bool,
+ device: str = 'cpu',
+ dummy_block_id: int = 0,
+ vocab_size: int = 1) -> ModelInputs:
+ """Create dummy model inputs."""
+ return make_dummy_inputs(batch_size,
+ max_q_seqlen=1,
+ is_decoding=is_decoding,
+ device=device,
+ dummy_block_id=dummy_block_id,
+ vocab_size=vocab_size)
diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py
new file mode 100644
index 0000000000..b2516f091a
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/sampling.py
@@ -0,0 +1,191 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+import torch
+
+from lmdeploy.pytorch.engine.logits_process import SamplingInputs
+from lmdeploy.pytorch.messages import SchedulerSequence
+
+from ..base.sampling import SamplingStrategy
+
+SeqList = List[SchedulerSequence]
+
+
+def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs):
+ """Gather history."""
+ if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors):
+ return None
+ batch = len(seqs)
+ max_len = max(seq.num_valid_ids for seq in seqs)
+ output = torch.full((batch, max_len), pad_id, dtype=torch.int64)
+ for idx, seq in enumerate(seqs):
+ h_len = seq.num_valid_ids
+ if h_len == 0:
+ continue
+ h_ids = torch.from_numpy(seq.valid_ids)
+ output[idx, -h_len:] = h_ids
+ return output
+
+
+def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'):
+ """Gather input ids for guided decode."""
+ if not any(sampling_inputs.response_formats or ()):
+ return None
+ batch = len(seqs)
+ max_len = max(seq.num_new_tokens for seq in seqs)
+ output = torch.full((batch, max_len), pad_id, dtype=torch.int64)
+ for idx, seq in enumerate(seqs):
+ h_len = seq.num_new_tokens
+ if h_len == 0:
+ continue
+ h_ids = torch.from_numpy(seq.generated_ids)
+ output[idx, -h_len:] = h_ids
+ return output
+
+
+def _get_num_ignore_eos(seqs: SeqList):
+ """Get num ignore eos."""
+ ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs]
+ return torch.tensor(ret)
+
+
+class ARSamplingStrategy(SamplingStrategy):
+ """Sampling strategy for autoregressive models."""
+
+ def __init__(self, pad_token_id: int) -> None:
+ pad_token_id = 0 if pad_token_id is None else pad_token_id
+ self.pad_token_id = pad_token_id
+
+ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
+ """Create sampling inputs from the sequences."""
+ batch_size = len(seqs)
+ temperature = [None] * batch_size
+ repetition_penalty = [None] * batch_size
+ top_k = [None] * batch_size
+ top_p = [None] * batch_size
+ min_p = [None] * batch_size
+ bad_words = [None] * batch_size
+ stop_words = [None] * batch_size
+ random_seeds = [torch.seed() & 0xffffffff] * batch_size
+ random_offsets = [None] * batch_size
+ response_formats = [None] * batch_size
+ logits_processors = [None] * batch_size
+ num_logprobs = [None] * batch_size
+
+ def __gather_params():
+ """Gather params."""
+ for idx, seq in enumerate(seqs):
+ param = seq.sampling_param
+ temperature[idx] = param.temperature
+ repetition_penalty[idx] = param.repetition_penalty
+ top_k[idx] = param.top_k
+ top_p[idx] = param.top_p
+ min_p[idx] = param.min_p
+ random_offsets[idx] = seq.num_valid_ids
+ response_formats[idx] = param.response_format
+ if param.random_seed is not None:
+ random_seeds[idx] = param.random_seed & 0xffffffff
+
+ bw = param.bad_words
+ sw = param.stop_words
+ if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens):
+ bw = bw + sw
+ bad_words[idx] = bw
+ stop_words[idx] = sw
+ logits_processors[idx] = param.logits_processors
+ num_logprobs[idx] = param.num_logprobs
+
+ def __get_topp(top_p):
+ """Get topp."""
+ min_top_p = min(top_p)
+ if min_top_p == 1.0:
+ top_p = None
+ else:
+ top_p = torch.tensor(top_p)
+ return top_p, min_top_p
+
+ def __get_minp(min_p):
+ """Get minp."""
+ max_min_p = max(min_p)
+ if max_min_p == 0.0:
+ min_p = None
+ else:
+ min_p = torch.Tensor(min_p)
+ return min_p
+
+ def __get_bad_words(bad_words):
+ """Get bad words."""
+ max_bw_len = max(len(bw) for bw in bad_words)
+ if max_bw_len == 0:
+ return None, None
+ if all(len(bw) == max_bw_len for bw in bad_words):
+ ret = torch.tensor(bad_words)
+ mask = torch.ones_like(ret, dtype=bool)
+ return ret, mask
+ ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)
+ for idx, bw in enumerate(bad_words):
+ bw_len = len(bw)
+ if bw_len == 0:
+ continue
+ bw = ret.new_tensor(bw)
+ ret[idx, :bw_len] = bw
+
+ mask = ret >= 0
+ ret = ret.where(mask, 0)
+ return ret, mask
+
+ __gather_params()
+
+ if all(rp == 1.0 for rp in repetition_penalty):
+ repetition_penalty = None
+ else:
+ repetition_penalty = torch.tensor(repetition_penalty)
+
+ temperature = torch.tensor(temperature)
+
+ bad_words, bad_mask = __get_bad_words(bad_words)
+ stop_words, stop_mask = __get_bad_words(stop_words)
+
+ max_top_k = max(top_k)
+ if min(top_k) <= 0:
+ max_top_k = 0
+ if max_top_k == 1:
+ top_k = None
+ top_p, min_top_p = None, 1.0
+ min_p = None
+ random_seeds = None
+ random_offsets = None
+ else:
+ top_k = torch.tensor(top_k)
+ top_p, min_top_p = __get_topp(top_p)
+ min_p = __get_minp(min_p)
+ random_seeds = torch.tensor(random_seeds)
+ random_offsets = torch.tensor(random_offsets)
+
+ max_num_logprobs = max(num_logprobs)
+
+ sampling_input = SamplingInputs(
+ temperature=temperature,
+ bad_words=bad_words,
+ bad_mask=bad_mask,
+ stop_words=stop_words,
+ stop_mask=stop_mask,
+ repetition_penalty=repetition_penalty,
+ top_k=top_k,
+ top_p=top_p,
+ min_p=min_p,
+ random_seeds=random_seeds,
+ random_offsets=random_offsets,
+ response_formats=tuple(response_formats),
+ max_top_k=max_top_k,
+ min_top_p=min_top_p,
+ logits_processors=logits_processors,
+ max_num_logprobs=max_num_logprobs,
+ batch_size=batch_size,
+ )
+
+ pad_token_id = self.pad_token_id
+ sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input)
+ sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input)
+ sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs)
+ return sampling_input
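
`_gather_all_ids` left-pads each sequence's ids to the batch maximum so the repetition-penalty kernel sees aligned rows. A toy version of that padding, with an assumed `pad_id` of 0:

```python
# Toy left-padding gather mirroring _gather_all_ids (pad_id assumed to be 0).
import torch

pad_id = 0
histories = [torch.tensor([5, 6, 7]), torch.tensor([9])]
max_len = max(len(h) for h in histories)
output = torch.full((len(histories), max_len), pad_id, dtype=torch.int64)
for idx, h in enumerate(histories):
    output[idx, -len(h):] = h
print(output)  # tensor([[5, 6, 7],
               #         [0, 0, 9]])
```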
diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py
new file mode 100644
index 0000000000..91a3335f18
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/ar/sequence.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+from torch import Tensor
+
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
+from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
+from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,
+ SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray)
+
+from ..base.sequence import SequenceStrategy
+
+SeqList = List[SchedulerSequence]
+
+
+@dataclass
+class SchedulerSequenceDefault(SchedulerSequence):
+
+ def update_token_ids(self,
+ token_ids: Tensor,
+ multimodals: MultiModalInputs = None,
+ embeddings: List[InputEmbeddings] = None,
+ model_meta: Dict[str, Any] = None,
+ mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
+ **kwargs):
+ """Update token ids, old token ids will be added to history."""
+ # update history image nums
+ self._update_embeddings(embeddings)
+
+ # update multimodals
+ self._update_multimodals(multimodals)
+
+ token_ids = _to_ndarray(token_ids)
+
+ num_valid = len(token_ids)
+
+ if mode == UpdateTokenMode.INPUTS:
+ self.arrive_time = time.perf_counter()
+ self.output_start_pos = self.num_all_ids + len(token_ids)
+ self._num_token_ids += num_valid
+ self.num_new_tokens = 0
+ else:
+ self._num_history_ids += self._num_token_ids
+ num_token_ids = num_valid
+ self._num_token_ids = num_token_ids
+ self.num_new_tokens += num_token_ids
+
+ self.history_cache.append(token_ids)
+
+ if model_meta is not None:
+ self.model_meta = model_meta
+
+ def set_step(self, step: int):
+ """Set step."""
+ num_all_ids = self.num_all_ids
+ # update step for vlm
+ if len(self.history_embeddings) > 0:
+ new_step, self._num_history_images, self._num_images = \
+ self.history_embeddings.get_step(step)
+ assert 0 <= new_step <= step
+ step = new_step
+ self._num_history_ids = step
+ self._num_token_ids = num_all_ids - step
+ self.num_ignored_history = min(step, self.num_ignored_history)
+
+ self.model_meta = None
+
+ # cross
+ if self.history_multimodals is not None:
+ self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids)
+ self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids)
+
+
+class ARSequenceStrategy(SequenceStrategy):
+
+ def make_sequence(self,
+ seq_id: int,
+ session: 'SchedulerSession',
+ sampling_param: 'SamplingParam' = None,
+ adapter_name: str = None,
+ migration_request: Optional[MigrationRequest] = None,
+ resp_cache: bool = False,
+ preserve_cache: bool = False) -> 'SchedulerSequence':
+ """Make sequence."""
+ return SchedulerSequenceDefault(seq_id=seq_id,
+ session=session,
+ sampling_param=sampling_param,
+ adapter_name=adapter_name,
+ migration_request=migration_request,
+ resp_cache=resp_cache,
+ preserve_cache=preserve_cache)
+
+ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None:
+ """Update running sequences."""
+ next_token_ids = batched_outputs.next_token_ids
+ stopped = batched_outputs.stopped
+ stopped = stopped.tolist()
+ model_metas = batched_outputs.model_metas
+ if model_metas is None:
+ model_metas = [None] * len(running)
+
+ next_token_ids = next_token_ids.numpy()
+ update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL
+ for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas):
+ if msg.status != MessageStatus.LOCKED:
+ continue
+
+ # fill token
+ msg.update_token_ids(token, model_meta=model_meta, mode=update_mode)
+ if stop:
+ msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED
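
The counter bookkeeping in `SchedulerSequenceDefault.update_token_ids` for the decode branch can be summarised with plain integers: the previous step's tokens are promoted to history and the freshly decoded tokens become the current ones. The numbers below are made up and only trace the arithmetic:

```python
# Made-up trace of the decode-mode bookkeeping in update_token_ids.
num_history_ids, num_token_ids, num_new_tokens = 10, 3, 0  # state before this decode step
new_tokens = [42]                                          # one freshly decoded token

num_history_ids += num_token_ids   # 13: old current tokens move into history
num_token_ids = len(new_tokens)    # 1:  the new token becomes the current one
num_new_tokens += num_token_ids    # 1:  counts generated tokens for stopping
print(num_history_ids, num_token_ids, num_new_tokens)
```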
diff --git a/lmdeploy/pytorch/strategies/base/__init__.py b/lmdeploy/pytorch/strategies/base/__init__.py
new file mode 100644
index 0000000000..42519779ad
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
+
+ from .cudagraph import CudagraphStrategy
+ from .engine import EngineStrategy
+ from .model_agent import ModelAgentStrategy
+ from .model_inputs import ModelInputsStrategy
+ from .sampling import SamplingStrategy
+ from .sequence import SequenceStrategy
+
+
+class StrategyFactoryBase(ABC):
+
+ @abstractmethod
+ def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
+ """Build cudagraph strategy."""
+ pass
+
+ @abstractmethod
+ def build_sampling_strategy(self) -> 'SamplingStrategy':
+ """Build sampling strategy."""
+ pass
+
+ @abstractmethod
+ def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
+ """Build model inputs strategy."""
+ pass
+
+ @abstractmethod
+ def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
+ """Build model agent strategy."""
+ pass
+
+ @abstractmethod
+ def build_engine_strategy(self, cache_config: 'CacheConfig',
+ scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
+ """Build engine strategy."""
+ pass
+
+ @abstractmethod
+ def build_sequence_strategy(self) -> 'SequenceStrategy':
+ """Build sequence strategy."""
+ pass
diff --git a/lmdeploy/pytorch/strategies/base/cudagraph.py b/lmdeploy/pytorch/strategies/base/cudagraph.py
new file mode 100644
index 0000000000..795c3b5350
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/cudagraph.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+
+
+class CudagraphStrategy(ABC):
+
+ @abstractmethod
+ def get_max_tokens(self, batch_size: int) -> int:
+ """Get max tokens."""
+ pass
diff --git a/lmdeploy/pytorch/strategies/base/engine.py b/lmdeploy/pytorch/strategies/base/engine.py
new file mode 100644
index 0000000000..1c8e0c4ef7
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/engine.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+
+
+class EngineStrategy(ABC):
+ """Engine strategy."""
+
+ @abstractmethod
+ def get_prealloc_size(self, is_decoding: bool) -> int:
+ """Get prealloc_size."""
+ pass
+
+ @abstractmethod
+ def get_num_loops(self, is_decoding: bool) -> int:
+ """Get num_loops."""
+ pass
diff --git a/lmdeploy/pytorch/strategies/base/model_agent.py b/lmdeploy/pytorch/strategies/base/model_agent.py
new file mode 100644
index 0000000000..8701a256fa
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/model_agent.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, fields
+from typing import TYPE_CHECKING, Any, List, Optional
+
+import numpy as np
+import torch
+
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.engine.logits_process import SamplingInputs
+ from lmdeploy.pytorch.messages import SchedulerSequence
+ from lmdeploy.pytorch.model_inputs import ModelInputs
+ SeqList = List[SchedulerSequence]
+
+
+def to_device(self, device: str, non_blocking: bool = False):
+ """To device."""
+ out_dict = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if isinstance(v, torch.Tensor):
+ v = v.to(device, non_blocking=non_blocking)
+ out_dict[k] = v
+
+ return type(self)(**out_dict)
+
+
+@dataclass
+class ExtraInputs(ABC):
+
+ def to_device(self, device: str, non_blocking: bool = False):
+ """To device."""
+ return to_device(self, device, non_blocking)
+
+
+@dataclass
+class ExtraOutputs(ABC):
+
+ def to_device(self, device: str, non_blocking: bool = False):
+ """To device."""
+ return to_device(self, device, non_blocking)
+
+ def to_cpu(self):
+ """To cpu."""
+ return self.to_device('cpu', non_blocking=False)
+
+ def to_numpy(self):
+ """To numpy."""
+ out = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if isinstance(v, torch.Tensor) and v.dtype != torch.bfloat16:
+ v = v.detach().numpy()
+ elif hasattr(v, 'to_numpy'):
+ v = v.to_numpy()
+ out[k] = v
+ return type(self)(**out)
+
+ def to_tensor(self):
+ """To tensor."""
+ out = dict()
+ for f in fields(self):
+ k = f.name
+ v = getattr(self, k)
+ if isinstance(v, np.ndarray):
+ v = torch.from_numpy(v)
+ elif hasattr(v, 'to_tensor'):
+ v = v.to_tensor()
+ out[k] = v
+ return type(self)(**out)
+
+
+@dataclass
+class StoppingCriteria(ABC):
+ """Base class for stopping criteria."""
+
+ @abstractmethod
+ def step(self,
+ token_ids: torch.Tensor,
+ stop_words: torch.Tensor,
+ inputs: Optional['ModelInputs'] = None,
+ extra_inputs: Optional[ExtraInputs] = None):
+ """Check whether to stop generation."""
+ pass
+
+ def to_device(self, device: str, non_blocking: bool = False):
+ """To device."""
+ return to_device(self, device, non_blocking)
+
+
+class ModelAgentStrategy(ABC):
+ """Base class for model agent strategies."""
+
+ @abstractmethod
+ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
+ """Slice outputs."""
+ pass
+
+ @abstractmethod
+ def slice_extra_inputs(self, extra_inputs: ExtraInputs, seq_length: torch.LongTensor) -> ExtraInputs:
+ """Slice extra inputs."""
+ pass
+
+ @abstractmethod
+ def make_stopping_criteria(self, seqs: 'SeqList') -> StoppingCriteria:
+ """Create stopping criteria."""
+ pass
+
+ @abstractmethod
+ def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs:
+ """Create extra inputs."""
+ pass
+
+ @abstractmethod
+ def make_extra_outputs(self, extra_inputs: ExtraInputs) -> ExtraOutputs:
+ """Create extra outputs."""
+ pass
+
+ @abstractmethod
+ def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs',
+ next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ExtraInputs,
+ **kwargs):
+ """Step next inputs."""
+ pass
+
+ @abstractmethod
+ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
+ extra_inputs: ExtraInputs):
+ """Post sampling."""
+ pass
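
The `to_device`/`to_numpy`/`to_tensor` helpers above simply walk the dataclass fields and convert tensor-like values while leaving everything else untouched. A self-contained sketch of the same pattern with a toy dataclass (not an LMDeploy type):

```python
# Self-contained sketch of the fields()-based conversion used by ExtraOutputs.
from dataclasses import dataclass, fields

import torch


@dataclass
class ToyOutputs:
    scores: torch.Tensor
    note: str = 'left untouched'

    def to_numpy(self):
        out = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, torch.Tensor):
                value = value.detach().numpy()
            out[f.name] = value
        return type(self)(**out)


print(ToyOutputs(scores=torch.ones(2)).to_numpy())
```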
diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py
new file mode 100644
index 0000000000..f27134077d
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/model_inputs.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+
+import torch
+from torch.profiler import record_function
+
+from lmdeploy.pytorch.model_inputs import ModelInputs
+
+
+@record_function('make_dummy_input')
+def make_dummy_inputs(batch_size: int,
+ max_q_seqlen: int,
+ is_decoding: bool,
+ device: str = 'cpu',
+ dummy_block_id: int = 0,
+ vocab_size: int = 1):
+ """Shared helper for building dummy model inputs."""
+ num_tokens = batch_size * max_q_seqlen
+ max_kv_seqlen = max_q_seqlen
+ input_ids = torch.randint(0, vocab_size, (
+ 1,
+ num_tokens,
+ ), dtype=torch.long, device=device)
+ seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long, device=device)
+ history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device)
+ block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device)
+ num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device)
+ local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device)
+
+ return ModelInputs(
+ input_ids=input_ids,
+ seq_length=seq_length,
+ history_lengths=history_lengths,
+ block_offsets=block_offsets,
+ is_decoding=is_decoding,
+ num_ignored_history=num_ignored_history,
+ max_q_seqlen=max_q_seqlen,
+ max_kv_seqlen=max_kv_seqlen,
+ sum_kv_seqlen=batch_size,
+ local_adapter_ids=local_adapter_ids,
+ )
+
+
+class ModelInputsStrategy(ABC):
+
+ @abstractmethod
+ def make_dummy(self,
+ batch_size: int,
+ is_decoding: bool,
+ device: str = 'cpu',
+ dummy_block_id: int = 0,
+ vocab_size: int = 1) -> ModelInputs:
+ """Create dummy model inputs."""
+ pass
diff --git a/lmdeploy/pytorch/strategies/base/sampling.py b/lmdeploy/pytorch/strategies/base/sampling.py
new file mode 100644
index 0000000000..172454157b
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/sampling.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from typing import List
+
+from lmdeploy.pytorch.engine.logits_process import SamplingInputs
+from lmdeploy.pytorch.messages import SchedulerSequence
+
+SeqList = List[SchedulerSequence]
+
+
+class SamplingStrategy(ABC):
+ """Base class for sampling strategies."""
+
+ @abstractmethod
+ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
+ """Create sampling inputs from the sequences."""
+ pass
diff --git a/lmdeploy/pytorch/strategies/base/sequence.py b/lmdeploy/pytorch/strategies/base/sequence.py
new file mode 100644
index 0000000000..408a3cc15e
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/base/sequence.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, List, Optional
+
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
+
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
+ from lmdeploy.pytorch.messages import SamplingParam, SchedulerSequence, SchedulerSession
+ SeqList = List[SchedulerSequence]
+
+
+class SequenceStrategy(ABC):
+
+ @abstractmethod
+ def make_sequence(self,
+ seq_id: int,
+ session: 'SchedulerSession',
+ sampling_param: 'SamplingParam' = None,
+ adapter_name: str = None,
+ migration_request: Optional[MigrationRequest] = None,
+ resp_cache: bool = False,
+ preserve_cache: bool = False) -> 'SchedulerSequence':
+ """Make sequence."""
+ pass
+
+ @abstractmethod
+ def update_running(self, running: 'SeqList', batched_outputs: 'BatchedOutputs', is_decoding: bool) -> None:
+ """Update running sequences."""
+ pass
diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py
new file mode 100644
index 0000000000..dc0395a017
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/__init__.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import TYPE_CHECKING
+
+from lmdeploy.pytorch.config import DLLMConfig, ModelConfig
+from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy
+from lmdeploy.utils import get_logger
+
+if TYPE_CHECKING:
+ from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy
+ from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
+ from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
+ from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
+ from lmdeploy.pytorch.strategies.base.engine import EngineStrategy
+ from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
+
+from ..base import StrategyFactoryBase
+
+logger = get_logger('lmdeploy')
+
+
+class DLLMStrategyFactory(StrategyFactoryBase):
+
+ def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig):
+ """Init with model and dllm configs."""
+ self.model_config = model_config
+ self.dllm_config = dllm_config
+
+ # update dllm_block_length
+ self.dllm_block_length = self._update_dllm_block_length()
+
+ def _update_dllm_block_length(self):
+ """Update dllm_block_length."""
+ if self.dllm_config.block_length is None:
+ dllm_block_length = self.model_config.dllm_block_length
+ if dllm_block_length is None:
+ dllm_block_length = 4
+ logger.warning('Model does not provide dllm_block_length. '
+ f'Using dllm_block_length={dllm_block_length} as the default.')
+ else:
+ dllm_block_length = self.dllm_config.block_length
+
+ assert dllm_block_length is not None, 'dllm_block_length should be set in model_config or dllm_config'
+
+ self.dllm_config.block_length = dllm_block_length
+ self.model_config.dllm_block_length = dllm_block_length
+
+ if self.dllm_config.denoising_steps is None:
+ self.dllm_config.denoising_steps = dllm_block_length
+ return dllm_block_length
+
+ def build_cudagraph_strategy(self) -> 'CudagraphStrategy':
+ """Build cudagraph strategy."""
+ from .cudagraph import DLLMCudagraphStrategy
+ return DLLMCudagraphStrategy(block_size=self.dllm_block_length)
+
+ def build_sampling_strategy(self) -> 'SamplingStrategy':
+ """Build sampling strategy."""
+ from .sampling import DLLMSamplingStrategy
+ pad_token_id = self.model_config.bos_token_id
+ pad_token_id = 0 if pad_token_id is None else pad_token_id
+ return DLLMSamplingStrategy(pad_token_id, self.dllm_block_length)
+
+ def build_model_inputs_strategy(self) -> 'ModelInputsStrategy':
+ """Build model inputs strategy."""
+ from .model_inputs import DLLMModelInputsStrategy
+ return DLLMModelInputsStrategy(block_size=self.dllm_block_length)
+
+ def build_model_agent_strategy(self) -> 'ModelAgentStrategy':
+ """Build model agent strategy."""
+ from .model_agent import DLLMModelAgentStrategy
+ return DLLMModelAgentStrategy(dllm_config=self.dllm_config, dllm_mask_token=self.model_config.dllm_mask_token)
+
+ def build_engine_strategy(self, cache_config: 'CacheConfig',
+ scheduler_config: 'SchedulerConfig') -> 'EngineStrategy':
+ """Build engine strategy."""
+ from .engine import DLLMEngineStrategy
+ return DLLMEngineStrategy(cache_config=cache_config,
+ scheduler_config=scheduler_config,
+ dllm_block_length=self.dllm_block_length)
+
+ def build_sequence_strategy(self) -> SequenceStrategy:
+ from .sequence import DLLMSequenceStrategy
+ return DLLMSequenceStrategy(block_size=self.dllm_block_length,
+ dllm_mask_token=self.model_config.dllm_mask_token)
diff --git a/lmdeploy/pytorch/strategies/dllm/cudagraph.py b/lmdeploy/pytorch/strategies/dllm/cudagraph.py
new file mode 100644
index 0000000000..2e388b22de
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/cudagraph.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..base.cudagraph import CudagraphStrategy
+
+
+class DLLMCudagraphStrategy(CudagraphStrategy):
+
+ def __init__(self, block_size: int) -> None:
+ super().__init__()
+ self.block_size = block_size
+
+ def get_max_tokens(self, batch_size: int) -> int:
+ """Get max tokens."""
+ return batch_size * self.block_size
diff --git a/lmdeploy/pytorch/strategies/dllm/engine.py b/lmdeploy/pytorch/strategies/dllm/engine.py
new file mode 100644
index 0000000000..32244a7abc
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/engine.py
@@ -0,0 +1,50 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import lru_cache
+
+from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
+from lmdeploy.utils import get_logger
+
+from ..base.engine import EngineStrategy
+
+logger = get_logger('lmdeploy')
+
+
+class DLLMEngineStrategy(EngineStrategy):
+ """DLLM Engine Strategy."""
+
+ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, dllm_block_length: int) -> None:
+ self.scheduler_config = scheduler_config
+ self.cache_config = cache_config
+ self.dllm_block_length = dllm_block_length
+
+ self._check()
+
+ def _check(self):
+ """Check that the dllm decode batch fits into max_prefill_token_num."""
+ max_prefill_token_num = self.cache_config.max_prefill_token_num
+ max_batches = self.cache_config.max_batches
+ if self.dllm_block_length * max_batches > max_prefill_token_num:
+ logger.warning(f'dllm_block_length({self.dllm_block_length}) * max_batch_size ({max_batches}) '
+ f'> max_prefill_token_num ({max_prefill_token_num}). '
+ 'This may lead to OOM. Consider reducing max_batch_size or dllm_block_length.')
+
+ @lru_cache(maxsize=2)
+ def get_prealloc_size(self, is_decoding: bool) -> int:
+ """Get prealloc_size."""
+ if not is_decoding:
+ return 0
+ block_size = self.cache_config.block_size
+ dllm_block_length = self.dllm_block_length
+ num_blocks = min(self.scheduler_config.prefill_interval // 2, block_size // dllm_block_length)
+ return num_blocks * dllm_block_length
+
+ @lru_cache(maxsize=2)
+ def get_num_loops(self, is_decoding: bool) -> int:
+ """Get num_loops."""
+ if not is_decoding:
+ return 1
+ block_size = self.cache_config.block_size
+ dllm_block_length = self.dllm_block_length
+ max_num_loops = block_size // dllm_block_length * 2
+ num_loops = min(self.scheduler_config.prefill_interval, max_num_loops)
+ return num_loops
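
The decoding-side arithmetic in `DLLMEngineStrategy` is easy to check by hand. With assumed values of `block_size=64` KV-cache slots per block, `dllm_block_length=4`, and `prefill_interval=16` (all made up for this sketch):

```python
# Hand-check of get_prealloc_size / get_num_loops for decoding (made-up config).
block_size, dllm_block_length, prefill_interval = 64, 4, 16

num_blocks = min(prefill_interval // 2, block_size // dllm_block_length)  # min(8, 16) = 8
prealloc_size = num_blocks * dllm_block_length                            # 32 tokens
num_loops = min(prefill_interval, block_size // dllm_block_length * 2)    # min(16, 32) = 16
print(prealloc_size, num_loops)  # 32 16
```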
diff --git a/lmdeploy/pytorch/strategies/dllm/model_agent.py b/lmdeploy/pytorch/strategies/dllm/model_agent.py
new file mode 100644
index 0000000000..a5104d4981
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/model_agent.py
@@ -0,0 +1,218 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass
+from typing import Any, List, Optional
+
+import numpy as np
+import torch
+from torch.profiler import record_function
+
+from lmdeploy.pytorch import consts
+from lmdeploy.pytorch.config import DLLMConfig
+from lmdeploy.pytorch.engine.logits_process import SamplingInputs
+from lmdeploy.pytorch.messages import SchedulerSequence
+from lmdeploy.pytorch.model_inputs import ModelInputs
+
+from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria
+from .unmasking import UnmaskingProcessor
+
+SeqList = List[SchedulerSequence]
+
+
+@dataclass
+class DLLMExtraInputs(ExtraInputs):
+ """DLLM extra inputs."""
+ dllm_mask: torch.Tensor
+
+
+@dataclass
+class DLLMExtraOutputs(ExtraOutputs):
+ """DLLM extra outputs."""
+ dllm_mask: torch.Tensor
+
+
+def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor,
+ stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor,
+ output_start_pos: torch.Tensor, inputs: ModelInputs):
+ num_tokens = token_ids.size(0)
+ batch_size = num_appendable_ids.size(0)
+ block_size = num_tokens // batch_size
+
+ # blocks might contain stop words in prev-round chat
+ # these stop words should be ignored
+ kv_seqlens = inputs.history_lengths + inputs.seq_length
+ ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0)
+ ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device)
+ ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten()
+ token_ids = token_ids.clone()
+ token_ids[ignore_mask] = -1
+
+ # find stop words
+ sw_stopped = (token_ids[:, None] == stop_words).any(1)
+ sw_stopped = sw_stopped.view(batch_size, block_size)
+ sw_stop_pos = sw_stopped.int().argmax(1)
+
+ stop_pos = torch.where(stopped, stop_pos, sw_stop_pos)
+ sw_stopped = sw_stopped.any(dim=1)
+ sw_stopped = sw_stopped & is_unmasked
+ stopped = stopped | sw_stopped
+
+ # update num_appendable_ids
+ one_ids = torch.clamp_max(num_appendable_ids, 0)
+ num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)
+
+ return stopped, stop_pos, num_appendable_ids
+
+
+@dataclass
+class DLLMStoppingCriteria(StoppingCriteria):
+ num_appendable_ids: torch.Tensor
+ output_start_pos: torch.Tensor
+
+ @record_function('stopping_criteria')
+ def step(self,
+ token_ids: torch.Tensor,
+ stop_words: torch.Tensor,
+ inputs: Optional[ModelInputs] = None,
+ extra_inputs: Optional[DLLMExtraInputs] = None):
+ """Check whether to stop generation."""
+ num_appendable_ids = self.num_appendable_ids
+ output_start_pos = self.output_start_pos
+ num_tokens = token_ids.size(0)
+ batch_size = num_appendable_ids.size(0)
+ block_size = num_tokens // batch_size
+
+ dllm_mask = extra_inputs.dllm_mask
+ dllm_mask = dllm_mask.view(batch_size, block_size)
+ is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1)
+
+ # check stop by num_new_tokens
+ num_appendable_ids -= is_unmasked * block_size
+ stopped = num_appendable_ids <= 0
+ stop_pos = block_size - 1 + num_appendable_ids
+
+ # check stop words
+ if stop_words is not None:
+ stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids,
+ stop_words,
+ is_unmasked,
+ stopped,
+ stop_pos,
+ num_appendable_ids,
+ output_start_pos=output_start_pos,
+ inputs=inputs)
+
+ new_stopping = DLLMStoppingCriteria(num_appendable_ids=num_appendable_ids, output_start_pos=output_start_pos)
+ return stopped, stop_pos, new_stopping
+
+
+class DLLMModelAgentStrategy(ModelAgentStrategy):
+
+ def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int):
+ block_size = dllm_config.block_length
+ self.block_size = block_size
+ self.dllm_mask_token = dllm_mask_token
+
+ self.unmasking_processor = UnmaskingProcessor(dllm_config=dllm_config)
+
+ def _update_dllm(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor):
+ """Update token_ids and dllm_mask."""
+ dllm_mask_token = self.dllm_mask_token
+ dllm_block_length = self.block_size
+
+ # reshape to (batch, dllm_block_length)
+ next_token_ids = next_token_ids.view(-1, dllm_block_length).clone()
+ dllm_mask = dllm_mask.view(-1, dllm_block_length).clone()
+
+ # flags
+ is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1)
+
+ is_masked = (dllm_mask == consts.DLLM_MASKED)
+ next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token
+ dllm_mask[is_cached] = consts.DLLM_MASKED
+ seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, )))
+
+ return next_token_ids.flatten(), dllm_mask.flatten(), seqlens
+
+ def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
+ """Slice outputs."""
+ block_length = self.block_size
+ # batch size == 1
+ if len(seq_length) == 1:
+ return inputs[-block_length:]
+
+ if len(seq_length) * block_length == inputs.size(0):
+ return inputs
+ last_idx = seq_length.cumsum(0)
+ block_range = torch.arange(-block_length, 0, device=last_idx.device)
+ index = (last_idx[:, None] + block_range[None, :]).flatten()
+ inputs = inputs[index]
+ return inputs
+
+ def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, seq_length: torch.LongTensor) -> DLLMExtraInputs:
+ """Slice extra inputs."""
+ dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, seq_length)
+ return DLLMExtraInputs(dllm_mask=dllm_mask)
+
+ def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor,
+ dllm_mask: torch.Tensor, **kwargs):
+ """step."""
+ from lmdeploy.pytorch import consts
+ dllm_block_size = self.block_size
+ DLLM_UNMASKED = consts.DLLM_UNMASKED
+ is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True)
+ num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size)
+ num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)
+ sampling_inputs.num_ignore_eos = num_ignore_eos.flatten()
+ return sampling_inputs
+
+ def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria:
+ """Create stopping criteria."""
+ # num_appendable
+ num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]
+ num_appendable = torch.tensor(num_appendable)
+ block_size = self.block_size
+ remain = [seq.num_valid_ids % block_size for seq in seqs]
+ num_appendable += torch.tensor(remain)
+
+ # output_start_pos
+ pos = [seq.output_start_pos for seq in seqs]
+ output_start_pos = torch.tensor(pos)
+
+ return DLLMStoppingCriteria(num_appendable_ids=num_appendable, output_start_pos=output_start_pos)
+
+ def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs:
+ """Create extra inputs."""
+ dllm_masks = [seq.dllm_mask for seq in seqs]
+ dllm_masks = torch.as_tensor(np.concatenate(dllm_masks))
+ return DLLMExtraInputs(dllm_mask=dllm_masks)
+
+ def make_extra_outputs(self, extra_inputs: DLLMExtraInputs) -> DLLMExtraOutputs:
+ """Create extra outputs."""
+ dllm_mask = extra_inputs.dllm_mask
+ return DLLMExtraOutputs(dllm_mask=dllm_mask)
+
+ def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs',
+ next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: DLLMExtraInputs,
+ **kwargs):
+ """Step next inputs."""
+ model_inputs.model_metas = model_metas
+ dllm_mask = extra_inputs.dllm_mask
+
+ next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length)
+ model_inputs.step(next_token_ids, step_seqlens)
+ self._step_sampling_inputs(sampling_inputs, next_token_ids, dllm_mask=dllm_mask)
+
+ extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask)
+ return model_inputs, extra_inputs
+
+ def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
+ extra_inputs: DLLMExtraInputs):
+ """Post sampling."""
+ dllm_mask = extra_inputs.dllm_mask
+ input_ids = inputs.input_ids
+ input_ids = self.slice_outputs(input_ids.flatten(), inputs.seq_length)
+
+ dllm_mask, next_token_ids = self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask)
+
+ extra_inputs.dllm_mask = dllm_mask
+ return next_token_ids, extra_inputs
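
To illustrate `_update_dllm` in isolation: a block whose mask is fully `CACHED` is reset to mask tokens and re-flagged as `MASKED`, while positions still marked `MASKED` keep the mask token. The flag values (0/1/2) and the mask token id below are placeholders, not the constants from `lmdeploy.pytorch.consts`:

```python
# Toy illustration of the cached-block reset in _update_dllm (placeholder constants).
import torch

DLLM_MASKED, DLLM_UNMASKED, DLLM_CACHED = 0, 1, 2  # placeholder flag values
mask_token = 151666                                # placeholder mask token id

token_ids = torch.tensor([[11, 12, 13, 14],        # block 0: fully cached
                          [21, 22, 23, 24]])       # block 1: partially denoised
dllm_mask = torch.tensor([[DLLM_CACHED] * 4,
                          [DLLM_UNMASKED, DLLM_UNMASKED, DLLM_MASKED, DLLM_MASKED]])

is_cached = (dllm_mask == DLLM_CACHED).all(dim=1)
is_masked = dllm_mask == DLLM_MASKED
token_ids[is_cached[:, None] | is_masked] = mask_token  # refill with mask tokens
dllm_mask[is_cached] = DLLM_MASKED                       # fresh block to denoise next step
print(token_ids)
print(dllm_mask)
```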
diff --git a/lmdeploy/pytorch/strategies/dllm/model_inputs.py b/lmdeploy/pytorch/strategies/dllm/model_inputs.py
new file mode 100644
index 0000000000..f05ba415f2
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/model_inputs.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from lmdeploy.pytorch.model_inputs import ModelInputs
+
+from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs
+
+
+class DLLMModelInputsStrategy(ModelInputsStrategy):
+
+ def __init__(self, block_size: int):
+ self.block_size = block_size
+
+ def make_dummy(self,
+ batch_size: int,
+ is_decoding: bool,
+ device: str = 'cpu',
+ dummy_block_id: int = 0,
+ vocab_size: int = 1) -> ModelInputs:
+ """Create dummy model inputs."""
+ return make_dummy_inputs(batch_size,
+ max_q_seqlen=self.block_size,
+ is_decoding=is_decoding,
+ device=device,
+ dummy_block_id=dummy_block_id,
+ vocab_size=vocab_size)
diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py
new file mode 100644
index 0000000000..2ad5d5ecd7
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/sampling.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from lmdeploy.pytorch.engine.logits_process import SamplingInputs
+from lmdeploy.pytorch.messages import SchedulerSequence
+
+from ..ar.sampling import ARSamplingStrategy
+
+SeqList = List[SchedulerSequence]
+
+
+class DLLMSamplingStrategy(ARSamplingStrategy):
+ """Sampling strategy for diffusion LLM (dLLM) models."""
+
+ def __init__(self, pad_token_id: int, dllm_block_length: int) -> None:
+ super().__init__(pad_token_id)
+ self.dllm_block_length = dllm_block_length
+
+ def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs:
+ """Create sampling inputs from the sequences."""
+ out = super().make_sampling_inputs(seqs)
+ dllm_block_length = self.dllm_block_length
+
+ # repeat tensor
+ update_attr_names = [
+ 'temperature',
+ 'bad_words',
+ 'bad_mask',
+ 'stop_words',
+ 'stop_mask',
+ 'repetition_penalty',
+ 'top_k',
+ 'top_p',
+ 'min_p',
+ 'random_seeds',
+ 'random_offsets',
+ 'all_ids',
+ 'guided_input_ids',
+ 'num_ignore_eos',
+ ]
+ for name in update_attr_names:
+ attr = getattr(out, name)
+ if attr is None:
+ continue
+ repeats = (dllm_block_length, ) + (1, ) * (attr.dim())
+ attr = attr[None].repeat(*repeats).flatten(0, 1)
+ setattr(out, name, attr)
+
+ if len(out.response_formats) > 0:
+ new_resp_formats = []
+ for resp in out.response_formats:
+ new_resp_formats += [resp] * dllm_block_length
+ out.response_formats = tuple(new_resp_formats)
+
+ out.batch_size *= dllm_block_length
+
+ return out
diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py
new file mode 100644
index 0000000000..ab004a2b63
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/sequence.py
@@ -0,0 +1,248 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+from torch import Tensor
+
+from lmdeploy.pytorch import consts
+from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest
+from lmdeploy.pytorch.engine.model_agent import BatchedOutputs
+from lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam,
+ SchedulerSession, UpdateTokenMode, _to_ndarray)
+
+from ..ar.sequence import SchedulerSequenceDefault
+from ..base.sequence import SequenceStrategy
+
+SeqList = List['SchedulerSequenceDLLM']
+
+DLLM_MASKED = consts.DLLM_MASKED
+DLLM_UNMASKED = consts.DLLM_UNMASKED
+DLLM_CACHED = consts.DLLM_CACHED
+DLLM_MASK_DTYPE = np.uint8
+
+
+class HistoryDLLMMask(HistoryTokenIds):
+
+ def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = DLLM_MASK_DTYPE):
+ super().__init__(token_ids=token_ids, dtype=dtype)
+
+
+@dataclass
+class SchedulerSequenceDLLM(SchedulerSequenceDefault):
+
+ # For dllm
+ history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask)
+
+ def __post_init__(self):
+ """Post init."""
+ super().__post_init__()
+ self._num_valid_ids: int = len(self.history_cache)
+ self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy
+
+ @property
+ def dllm_mask(self):
+ start = self.num_history_ids
+ end = start + self._num_token_ids
+ return self.history_dllm_mask._token_ids[start:end]
+
+ @property
+ def num_valid_ids(self):
+ return self._num_valid_ids
+
+ @property
+ def generated_ids(self) -> np.ndarray:
+ end = self.num_valid_ids
+ start = end - self.num_new_tokens
+ return self.history_cache._token_ids[start:end]
+
+ @property
+ def all_dllm_mask(self):
+ return self.history_dllm_mask._token_ids[:self.num_all_ids]
+
+ @property
+ def dllm_block_length(self):
+ return self._strategy.block_size
+
+ @property
+ def dllm_mask_token(self):
+ return self._strategy.dllm_mask_token
+
+ def set_stop_pos(self, pos: int):
+ dllm_block_length = self.dllm_block_length
+ val = dllm_block_length - pos - 1
+ self._num_valid_ids -= val
+ self.num_new_tokens -= val
+
+ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
+ """Append tokens."""
+ num_tokens = len(token_ids)
+ dllm_block_length = self.dllm_block_length
+ dllm_mask_token = self.dllm_mask_token
+ new_token_ids = [token_ids]
+ new_dllm_mask = [dllm_mask]
+
+ # add uncached tokens in token_ids
+ # e.g. for [cccc cccc uumm], the [uu] in the last block remains valid.
+ num_remain_valid = self.num_valid_ids - self.num_history_ids
+ if num_remain_valid != 0:
+ prev_token_ids = self.valid_ids[-num_remain_valid:]
+ prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
+ new_token_ids = [prev_token_ids] + new_token_ids
+ new_dllm_mask = [prev_dllm_mask] + new_dllm_mask
+ self.history_cache.resize(self.num_history_ids)
+ self.history_dllm_mask.resize(self.num_history_ids)
+ num_tokens += num_remain_valid
+
+ # pad to align with dllm_block_length
+ num_pad = (-num_tokens) % dllm_block_length
+ if num_pad > 0:
+ pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, ))
+ pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, ))
+ new_token_ids += [pad_ids]
+ new_dllm_mask += [pad_mask]
+
+ token_ids = np.concatenate(new_token_ids)
+ dllm_mask = np.concatenate(new_dllm_mask)
+
+ assert len(token_ids) % dllm_block_length == 0
+
+ self.history_cache.append(token_ids)
+ self.history_dllm_mask.append(dllm_mask)
+ self.output_start_pos = self._num_valid_ids + len(token_ids)
+ self._num_valid_ids = self.num_history_ids + num_tokens
+ self._num_token_ids = len(token_ids)
+ self.num_new_tokens = 0
+
+ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
+ """Update token ids for decode."""
+ num_tokens = len(token_ids)
+ dllm_block_length = self.dllm_block_length
+ dllm_mask_token = self.dllm_mask_token
+ assert num_tokens % dllm_block_length == 0
+ num_history_ids = self.num_history_ids
+
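+ # masked positions keep the mask token id; write the refreshed block back into history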
+ token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token
+ self.history_cache[num_history_ids:] = token_ids
+ self.history_dllm_mask[num_history_ids:] = dllm_mask
+
+ # check whether the last block is fully unmasked or already cached
+ last_mask = dllm_mask[-dllm_block_length:]
+ is_unmasked = np.all(last_mask == DLLM_UNMASKED)
+ is_cached = np.all(last_mask == DLLM_CACHED)
+
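+ # the last block is now fully unmasked; count its remaining positions as newly generated tokens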
+ if is_unmasked:
+ num_new = dllm_block_length - self._num_valid_ids % dllm_block_length
+ self._num_valid_ids += num_new
+ self.num_new_tokens += num_new
+
+ if is_cached:
+ # add new block
+ new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, ))
+ new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, ))
+ self.history_cache.append(new_token_ids)
+ self.history_dllm_mask.append(new_dllm_mask)
+ self._num_history_ids += self._num_token_ids
+ self._num_token_ids = dllm_block_length
+
+ def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray):
+ """Update token ids for prefill."""
+ dllm_block_length = self.dllm_block_length
+ num_history_ids = self.num_history_ids
+
+ # mark all but the trailing block as cached
+ if self.num_token_ids > dllm_block_length:
+ end = self.num_token_ids - dllm_block_length
+ self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED
+ self._num_history_ids += end
+ self._num_token_ids -= end
+
+ # decoding update
+ self._update_token_ids_decode(token_ids, dllm_mask)
+
+ def update_token_ids(self,
+ token_ids: Tensor,
+ multimodals: MultiModalInputs = None,
+ embeddings: List[InputEmbeddings] = None,
+ model_meta: Dict[str, Any] = None,
+ dllm_mask: Tensor = None,
+ mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
+ **kwargs):
+ """Update token ids, old token ids will be added to history."""
+ # update history image nums
+ self._update_embeddings(embeddings)
+
+ # update multimodals
+ self._update_multimodals(multimodals)
+
+ self.arrive_time = time.perf_counter()
+
+ token_ids: np.ndarray = _to_ndarray(token_ids)
+ if dllm_mask is None:
+ dllm_mask = np.full_like(token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE)
+ dllm_mask: np.ndarray = _to_ndarray(dllm_mask)
+
+ if mode == UpdateTokenMode.INPUTS:
+ self._update_token_ids_inputs(token_ids, dllm_mask)
+ elif mode == UpdateTokenMode.PREFILL:
+ self._update_token_ids_prefill(token_ids, dllm_mask)
+ else:
+ self._update_token_ids_decode(token_ids, dllm_mask)
+
+ if model_meta is not None:
+ self.model_meta = model_meta
+
+
+class DLLMSequenceStrategy(SequenceStrategy):
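+ """Sequence strategy for dllm block diffusion decoding."""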
+
+ def __init__(self, block_size: int, dllm_mask_token: int) -> None:
+ self.block_size = block_size
+ self.dllm_mask_token = dllm_mask_token
+
+ def make_sequence(self,
+ seq_id: int,
+ session: 'SchedulerSession',
+ sampling_param: 'SamplingParam' = None,
+ adapter_name: str = None,
+ migration_request: Optional[MigrationRequest] = None,
+ resp_cache: bool = False,
+ preserve_cache: bool = False) -> 'SchedulerSequenceDLLM':
+ """Make sequence."""
+ return SchedulerSequenceDLLM(seq_id=seq_id,
+ session=session,
+ sampling_param=sampling_param,
+ adapter_name=adapter_name,
+ migration_request=migration_request,
+ resp_cache=resp_cache,
+ preserve_cache=preserve_cache)
+
+ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None:
+ """Update running sequences."""
+ next_token_ids = batched_outputs.next_token_ids
+ stopped = batched_outputs.stopped
+ stopped = stopped.tolist()
+ model_metas = batched_outputs.model_metas
+ if model_metas is None:
+ model_metas = [None] * len(running)
+ dllm_mask = batched_outputs.extra_outputs.dllm_mask
+ stop_pos = batched_outputs.stop_pos
+
+ batch_size = len(running)
+ next_token_ids = next_token_ids.view(batch_size, -1).numpy()
+ dllm_mask = dllm_mask.view(batch_size, -1).numpy()
+ stop_pos = stop_pos.tolist()
+ update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL
+ for idx, token in enumerate(next_token_ids):
+ msg = running[idx]
+ stop = stopped[idx]
+ model_meta = model_metas[idx]
+ mask = dllm_mask[idx]
+ if msg.status != MessageStatus.LOCKED:
+ continue
+
+ # fill token
+ msg.update_token_ids(token, dllm_mask=mask, model_meta=model_meta, mode=update_mode)
+ if stop:
+ msg.set_stop_pos(stop_pos[idx])
+ msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED
diff --git a/lmdeploy/pytorch/strategies/dllm/unmasking.py b/lmdeploy/pytorch/strategies/dllm/unmasking.py
new file mode 100644
index 0000000000..7c24ac8d3f
--- /dev/null
+++ b/lmdeploy/pytorch/strategies/dllm/unmasking.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.profiler import record_function
+
+from lmdeploy.pytorch import consts
+from lmdeploy.pytorch.config import DLLMConfig, UnmaskingStrategy
+
+DLLM_MASKED = consts.DLLM_MASKED
+DLLM_UNMASKED = consts.DLLM_UNMASKED
+DLLM_CACHED = consts.DLLM_CACHED
+
+
+class UnmaskingProcessor:
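+ """Update dllm mask states after each forward pass according to the configured unmasking strategy."""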
+
+ def __init__(self, dllm_config: DLLMConfig):
+ self.dllm_config = dllm_config
+
+ def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor):
+ """Get scores."""
+ scores = logits.softmax(dim=-1)
+ scores = scores.gather(-1, token_ids.unsqueeze(-1)).flatten()
+ return scores
+
+ def _get_denoise_num(self):
+ """Get denoise num."""
+ block_size = self.dllm_config.block_length
+ denoising_steps = self.dllm_config.denoising_steps
+ if denoising_steps is None:
+ denoising_steps = block_size
+ num = block_size // denoising_steps
+ num = max(1, min(num, block_size))
+ return num
+
+ def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
+ """static."""
+ block_size = self.dllm_config.block_length
+ topk = self._get_denoise_num()
+ scores = self._get_scores(logits, token_ids)
+ is_masked = dllm_mask == DLLM_MASKED
+ scores = torch.where(is_masked, scores, scores.new_zeros((1, )))
+
+ scores = scores.view(-1, block_size)
+ dllm_mask = dllm_mask.view(-1, block_size)
+ _, indices = scores.topk(topk, dim=-1)
+ dllm_unmasked = dllm_mask.scatter(-1, indices, DLLM_UNMASKED)
+
+ is_masked = is_masked.view_as(dllm_mask)
+ dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)
+ return dllm_mask.flatten()
+
+ def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
+ """dynamic."""
+ block_size = self.dllm_config.block_length
+ threshold = self.dllm_config.confidence_threshold
+ scores = self._get_scores(logits, token_ids)
+ is_masked = dllm_mask == DLLM_MASKED
+ scores = torch.where(is_masked, scores, scores.new_zeros((1, )))
+
+ scores = scores.view(-1, block_size)
+ dllm_mask = dllm_mask.view(-1, block_size)
+ _, indices = scores.topk(1, dim=-1)
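+ # raise the top-1 masked score to the threshold so at least one token is unmasked every step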
+ scores = scores.scatter(-1, indices, threshold)
+
+ is_masked = is_masked.view_as(dllm_mask)
+ is_masked &= scores >= threshold
+ dllm_mask[is_masked] = DLLM_UNMASKED
+ return dllm_mask.flatten()
+
+ def sequential(self, dllm_mask: torch.Tensor):
+ """sequential."""
+ block_size = self.dllm_config.block_length
+ denoise_num = self._get_denoise_num()
+ dllm_mask = dllm_mask.view(-1, block_size)
+ is_masked = dllm_mask == DLLM_MASKED
+
+ # get indices
+ indices = is_masked.int().argmax(dim=1)
+ ranges = torch.arange(0, denoise_num, device=indices.device, dtype=indices.dtype)
+ indices = indices[:, None] + ranges[None, :]
+ indices = indices % block_size
+
+ dllm_unmasked = dllm_mask.clone()
+ dllm_unmasked = dllm_unmasked.scatter(-1, indices, DLLM_UNMASKED)
+ dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)
+
+ return dllm_mask.flatten()
+
+ @record_function('unmasking')
+ def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
+ """call."""
+ strategy = self.dllm_config.unmasking_strategy
+ if strategy is None:
+ return dllm_mask, token_ids
+
+ # reshape to [num_blocks, block_size]
+ block_size = self.dllm_config.block_length
+ dllm_mask = dllm_mask.unflatten(0, (-1, block_size))
+
+ is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1)
+ first_mask = dllm_mask[:, 0]
+
+ # unmasked to cache
+ is_block_unmasked = is_same & (first_mask == DLLM_UNMASKED)
+ dllm_mask[is_block_unmasked] = DLLM_CACHED
+
+ dllm_mask = dllm_mask.flatten()
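+ # non-masked positions keep their existing input ids; only masked positions take the newly sampled tokens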
+ token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids)
+ if strategy == UnmaskingStrategy.LOW_CONFIDENCE_STATIC:
+ dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask)
+ elif strategy == UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC:
+ dllm_mask = self.low_confidence_dynamic(logits, token_ids, dllm_mask)
+ elif strategy == UnmaskingStrategy.SEQUENTIAL:
+ dllm_mask = self.sequential(dllm_mask)
+ else:
+ raise RuntimeError(f'strategy {strategy} not supported.')
+
+ return dllm_mask, token_ids
diff --git a/lmdeploy/pytorch/tools/utils.py b/lmdeploy/pytorch/tools/utils.py
index 32b707b140..98d42772c0 100644
--- a/lmdeploy/pytorch/tools/utils.py
+++ b/lmdeploy/pytorch/tools/utils.py
@@ -132,6 +132,13 @@ def visualize_pipe_out(outputs, enable_meta: bool = True):
from lmdeploy.messages import Response
+ try:
+ from termcolor import colored
+ except ImportError:
+
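+ # fall back to plain text when termcolor is not installed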
+ def colored(text, color=None, on_color=None, attrs=None):
+ return text
+
if isinstance(outputs, Response):
outputs = [outputs]
try:
@@ -139,24 +146,49 @@ def visualize_pipe_out(outputs, enable_meta: bool = True):
except Exception:
term_size = 100
- def _lined_print(msg: str, line_format: str = '-', full_line: bool = False):
- print(msg)
- if full_line:
- columns = term_size
- else:
- columns = max(len(m) for m in msg.split('\n'))
- print(line_format * columns)
+ border_color = 'cyan'
+ meta_color = 'light_grey'
+ number_color = 'green'
+
+ def _print_title(title: str, color: str = border_color):
+ title_text = f' {title} '
+ print(colored(f'【{title_text}】', color, attrs=['bold']))
+
+ def _print_section(title: str, content: str, color: str = border_color):
+ """Simple title and content printing."""
+ _print_title(title, color)
+ print(content)
+
+ def _print_meta(out: Response):
+ """Enhanced meta information display."""
+ # Create a clean table-like format
+ finish_color = 'yellow' if out.finish_reason == 'stop' else 'red'
+ meta_content = [
+ f"{colored('• Input Tokens:', meta_color)} {colored(out.input_token_len, number_color)}",
+ f"{colored('• Generated Tokens:', meta_color)} {colored(out.generate_token_len, number_color)}",
+ f"{colored('• Finish Reason:', meta_color)} {colored(out.finish_reason, finish_color)}"
+ ]
+
+ lines = '\n'.join(meta_content)
+ lines += '\n'
+ _print_section('METADATA', lines, border_color)
+
+ # top border, then render each output
+ print(colored('━' * term_size, border_color))
outputs: List[Response] = outputs
- term_line = '—' * term_size
- print(term_line)
for idx, out in enumerate(outputs):
- _lined_print(f'output[{idx}]', '=')
+ header = f'OUTPUT [{idx + 1}/{len(outputs)}]'
+ header_formatted = colored(f'✦ {header}', 'light_magenta', attrs=['bold'])
+ print(header_formatted)
+ print()
+
if enable_meta:
- _lined_print('meta', '-')
- _lined_print(
- f'input_token_len={out.input_token_len}\n'
- f'generate_token_len={out.generate_token_len}\n'
- f'finish_reason="{out.finish_reason}"', '—')
- _lined_print('text', '-')
- _lined_print(f'{out.text}', '—', full_line=True)
+ _print_meta(out)
+
+ _print_section('TEXT', out.text, border_color)
+
+ if idx < len(outputs) - 1: # Add separator when it's not the last output
+ print(colored('─' * (term_size), border_color, attrs=['dark']))
+ else:
+ print(colored('━' * term_size, border_color))
diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py
index 2b6d17aa2f..5e1ad7cf1a 100644
--- a/tests/pytorch/kernel/test_flash_attention.py
+++ b/tests/pytorch/kernel/test_flash_attention.py
@@ -11,16 +11,18 @@ def _conti_input(data, q_seqlens):
def _make_bias(q_seqlens, history_lens, neg_val, causal):
+ batch_size = q_seqlens.shape[0]
kv_seqlens = q_seqlens + history_lens
max_seq_len = q_seqlens.max().item()
max_kv_len = kv_seqlens.max().item()
if causal:
- seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens]
- for r, l in zip(seq_ranges, q_seqlens):
- r[l:] = -max_kv_len
- seq_ranges = torch.stack(seq_ranges, dim=0).cuda()
- kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens]
- kv_ranges = torch.stack(kv_ranges, 0).cuda()
+ seq_ranges = torch.arange(max_seq_len).cuda()
+ seq_ranges = seq_ranges.repeat(batch_size, 1)
+ seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)
+
+ kv_ranges = torch.arange(max_kv_len).cuda()
+ kv_ranges = kv_ranges.repeat(batch_size, 1)
+
mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
return mask.float() * neg_val
else:
@@ -31,6 +33,27 @@ def _make_bias(q_seqlens, history_lens, neg_val, causal):
return (~mask).float() * neg_val
+def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float,
+ block_sparse_size: int):
+ """Make block sparse bias."""
+ batch_size = q_seqlens.shape[0]
+ kv_seqlens = q_seqlens + history_lens
+ max_seq_len = q_seqlens.max().item()
+ max_kv_len = kv_seqlens.max().item()
+
+ seq_ranges = torch.arange(max_seq_len).cuda()
+ seq_ranges = seq_ranges // block_sparse_size * block_sparse_size
+ seq_ranges = seq_ranges.repeat(batch_size, 1)
+ seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)
+
+ kv_ranges = torch.arange(max_kv_len).cuda()
+ kv_ranges = kv_ranges // block_sparse_size * block_sparse_size
+ kv_ranges = kv_ranges.repeat(batch_size, 1)
+
+ mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
+ return mask.float() * neg_val
+
+
def _naive_attention(batched_q, batched_kv, bias, sinks=None):
batched_k, batched_v = batched_kv
@@ -283,3 +306,42 @@ def test_sinks(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv
sinks=sinks,
causal=causal)
torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5)
+
+ # block sparse attention
+ @pytest.fixture
+ def block_sparse_size(self):
+ yield 4
+
+ @pytest.fixture
+ def block_sparse_mask(self, q_seqlens, history_lens, block_sparse_size):
+ neg_val = -1e30
+ yield _make_block_sparse_bias(q_seqlens, history_lens, neg_val, block_sparse_size)
+
+ @pytest.fixture
+ def block_sparse_gt(self, batched_q, batched_kv, block_sparse_mask):
+ yield _naive_attention(batched_q, batched_kv, block_sparse_mask)
+
+ @pytest.mark.parametrize('head_dim_k', [32], indirect=True)
+ @pytest.mark.parametrize('head_dim_v', [32], indirect=True)
+ @pytest.mark.parametrize('num_heads_q', [8], indirect=True)
+ @pytest.mark.parametrize('num_heads_k', [2], indirect=True)
+ @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([16, 32], [64, 8])], indirect=True)
+ def test_block_sparse_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv_seqlens,
+ head_dim_v, block_sparse_size, block_sparse_gt):
+ from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attention_fwd
+ max_seq_len = q_seqlens.max().item()
+
+ conti_k, conti_v = conti_kv
+ out = conti_q.new_empty(*conti_q.shape[:-1], head_dim_v)
+ flash_attention_fwd(conti_q,
+ conti_k,
+ conti_v,
+ out,
+ q_start_loc=q_start_loc,
+ q_seqlens=q_seqlens,
+ kv_start_loc=kv_start_loc,
+ kv_seqlens=kv_seqlens,
+ max_seqlen=max_seq_len,
+ block_sparse_size=block_sparse_size)
+ gt = _conti_input(block_sparse_gt, q_seqlens)
+ torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-5)
diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py
index a6289fc626..f4268cacdc 100644
--- a/tests/pytorch/kernel/test_paged_attention.py
+++ b/tests/pytorch/kernel/test_paged_attention.py
@@ -10,20 +10,42 @@ def _conti_input(data, seq_lens):
return data
-def _make_bias(seq_lens, history_lens, neg_val):
- full_seq_lens = seq_lens + history_lens
- max_seq_len = seq_lens.max().item()
- max_full_len = full_seq_lens.max().item()
- seq_ranges = [torch.arange(max_seq_len) for _ in seq_lens]
- for r, l in zip(seq_ranges, seq_lens):
- r[l:] = -max_full_len
- seq_ranges = torch.stack(seq_ranges, dim=0).cuda()
- kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens]
- kv_ranges = torch.stack(kv_ranges, 0).cuda()
+def _make_bias(q_seqlens, history_lens, neg_val):
+ batch_size = q_seqlens.shape[0]
+ full_seq_lens = q_seqlens + history_lens
+ max_seq_len = q_seqlens.max().item()
+ max_kv_len = full_seq_lens.max().item()
+ seq_ranges = torch.arange(max_seq_len).cuda()
+ seq_ranges = seq_ranges.repeat(batch_size, 1)
+ seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)
+
+ kv_ranges = torch.arange(max_kv_len).cuda()
+ kv_ranges = kv_ranges.repeat(batch_size, 1)
mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]
return mask.float() * neg_val
+def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float,
+ block_sparse_size: int):
+ """Make block sparse bias."""
+ batch_size = q_seqlens.shape[0]
+ kv_seqlens = q_seqlens + history_lens
+ max_seq_len = q_seqlens.max().item()
+ max_kv_len = kv_seqlens.max().item()
+
+ seq_ranges = torch.arange(max_seq_len).cuda()
+ seq_ranges = seq_ranges // block_sparse_size * block_sparse_size
+ seq_ranges = seq_ranges.repeat(batch_size, 1)
+ seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len)
+
+ kv_ranges = torch.arange(max_kv_len).cuda()
+ kv_ranges = kv_ranges // block_sparse_size * block_sparse_size
+ kv_ranges = kv_ranges.repeat(batch_size, 1)
+
+ mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None])
+ return mask.float() * neg_val
+
+
def _make_blocked_cache(batched_k,
batched_v,
seq_lens,
@@ -119,7 +141,7 @@ def _make_cu_seqlens(seqlens):
return output
-class TestPagedAttention:
+class TestPagedAttentionBase:
@pytest.fixture
def dtype(self):
@@ -154,18 +176,22 @@ def history_lens(self, request):
yield torch.tensor(request.param, device='cuda')
@pytest.fixture
- def seq_lens(self, history_lens):
- yield torch.ones_like(history_lens)
+ def seq_len(self):
+ yield 1
+
+ @pytest.fixture
+ def seq_lens(self, seq_len, history_lens):
+ yield torch.ones_like(history_lens) * seq_len
@pytest.fixture
- def kv_seqlens(self, history_lens):
- yield 1 + history_lens
+ def kv_seqlens(self, seq_lens, history_lens):
+ yield seq_lens + history_lens
@pytest.fixture
- def batched_q(self, kv_seqlens, num_heads_q, feat_dim, dtype):
+ def batched_q(self, seq_len, kv_seqlens, num_heads_q, feat_dim, dtype):
torch.manual_seed(123)
batch_size = len(kv_seqlens)
- inputs = torch.rand(batch_size, 1, num_heads_q, feat_dim, dtype=dtype, device='cuda')
+ inputs = torch.rand(batch_size, seq_len, num_heads_q, feat_dim, dtype=dtype, device='cuda')
yield inputs
@pytest.fixture
@@ -178,8 +204,7 @@ def batched_kv(self, kv_seqlens, num_heads_k, feat_dim, feat_dim_v, dtype):
yield k, v
@pytest.fixture
- def conti_q(self, kv_seqlens, batched_q):
- seq_lens = torch.ones_like(kv_seqlens)
+ def conti_q(self, seq_lens, batched_q):
yield _conti_input(batched_q, seq_lens)
@pytest.fixture
@@ -225,7 +250,10 @@ def gt(self, batched_q, batched_kv, mask):
def conti_gt(self, gt, seq_lens):
yield _conti_input(gt, seq_lens)
- @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
+
+class TestPagedAttention(TestPagedAttentionBase):
+
+ @pytest.mark.parametrize('feat_dim', [32, 32], indirect=True)
@pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
@pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True)
@pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)
@@ -288,7 +316,7 @@ def test_window_attention(self, conti_q, blocked_kv, block_offsets, history_lens
torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5)
-class TestPagedAttentionSink(TestPagedAttention):
+class TestPagedAttentionSink(TestPagedAttentionBase):
@pytest.fixture
def sinks(self, num_heads_q, dtype):
@@ -308,7 +336,8 @@ def conti_sink_gt(self, sink_gt, seq_lens):
@pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True)
@pytest.mark.parametrize('block_size', [16], indirect=True)
@pytest.mark.parametrize('layout', ['bshd'], indirect=True)
- def test_sink(self, conti_q, blocked_kv, block_offsets, history_lens, feat_dim_v, layout, sinks, conti_sink_gt):
+ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, history_lens, feat_dim_v, layout, sinks,
+ conti_sink_gt):
from lmdeploy.pytorch.kernels.cuda import paged_attention_fwd
kv_seq_lens = 1 + history_lens
@@ -450,3 +479,46 @@ class TestPagedAttentionInt4(TestPagedAttentionInt8):
@pytest.fixture
def nbits(self):
yield 4
+
+
+class TestPagedAttentionBlockDecoding(TestPagedAttentionBase):
+
+ @pytest.fixture
+ def seq_len(self):
+ yield 4
+
+ @pytest.fixture
+ def mask(self, seq_lens, history_lens, seq_len):
+ neg_val = -1e30
+ yield _make_block_sparse_bias(seq_lens, history_lens, neg_val, seq_len)
+
+ @pytest.fixture
+ def gt(self, batched_q, batched_kv, mask):
+ yield _naive_attention(batched_q, batched_kv, mask)
+
+ @pytest.fixture
+ def conti_gt(self, gt, seq_lens):
+ yield _conti_input(gt, seq_lens)
+
+ @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
+ @pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
+ @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True)
+ @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True)
+ @pytest.mark.parametrize('block_size', [16], indirect=True)
+ @pytest.mark.parametrize('layout', ['bshd'], indirect=True)
+ def test_paged_attention(self, conti_q, blocked_kv, block_offsets, seq_lens, history_lens, feat_dim_v, layout,
+ conti_gt):
+ from lmdeploy.pytorch.kernels.cuda import paged_attention_fwd
+ kv_seq_lens = seq_lens + history_lens
+
+ blocked_k, blocked_v = blocked_kv
+ out = conti_q.new_empty(*conti_q.shape[:-1], feat_dim_v)
+
+ paged_attention_fwd(conti_q,
+ blocked_k,
+ blocked_v,
+ out,
+ block_offsets=block_offsets,
+ kv_seqlens=kv_seq_lens,
+ kv_layout=layout)
+ torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5)
diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py
index cdf45d5d6d..f74b6548cf 100644
--- a/tests/pytorch/paging/test_block_manager.py
+++ b/tests/pytorch/paging/test_block_manager.py
@@ -2,7 +2,7 @@
import pytest
import torch
-from lmdeploy.pytorch.messages import SchedulerSession
+from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta
from lmdeploy.pytorch.paging.block_manager import DefaultBlockManager, WindowBlockManager
from lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator
@@ -89,8 +89,16 @@ def num_gpu_blocks(self):
def block_mgr(self, num_cpu_blocks, num_gpu_blocks):
yield DefaultBlockManager(num_cpu_blocks, num_gpu_blocks)
- def test_alloc(self, block_mgr, block_size, num_gpu_blocks):
- sess = SchedulerSession(0, block_size)
+ @pytest.fixture
+ def seq_manager(self, block_size):
+ from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
+ strategy = ARSequenceStrategy()
+ seq_meta = SequenceMeta(block_size, strategy=strategy)
+ yield SequenceManager(seq_meta)
+
+ def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks):
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
# test alloc
token_ids = torch.tensor([1])
@@ -113,9 +121,10 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks):
msg = sess.add_sequence(token_ids)
assert not block_mgr.can_allocate(msg)
- def test_num_required_blocks(self, block_mgr, block_size, num_gpu_blocks):
+ def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks):
from lmdeploy.pytorch.messages import InputEmbeddings
- sess = SchedulerSession(0, block_size)
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
token_ids = torch.tensor([1])
msg = sess.add_sequence(token_ids)
@@ -133,8 +142,9 @@ def test_num_required_blocks(self, block_mgr, block_size, num_gpu_blocks):
num_required = block_mgr.num_required_blocks(msg)
assert num_required == 3
- def test_append_slot(self, block_mgr, block_size, num_gpu_blocks):
- sess = SchedulerSession(0, block_size)
+ def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks):
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
# test append
token_ids = torch.tensor([1])
@@ -158,8 +168,9 @@ def test_append_slot(self, block_mgr, block_size, num_gpu_blocks):
assert len(block_table) == 2
assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2
- def test_swap(self, block_mgr, block_size, num_gpu_blocks):
- sess = SchedulerSession(0, block_size)
+ def test_swap(self, block_mgr, seq_manager, num_gpu_blocks):
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
token_ids = torch.tensor([1] * (block_size + 1))
msg = sess.add_sequence(token_ids)
@@ -215,12 +226,20 @@ def num_cpu_blocks(self):
def num_gpu_blocks(self):
yield 4
+ @pytest.fixture
+ def seq_manager(self, block_size):
+ from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
+ strategy = ARSequenceStrategy()
+ seq_meta = SequenceMeta(block_size, strategy=strategy)
+ yield SequenceManager(seq_meta)
+
@pytest.fixture
def block_mgr(self, num_cpu_blocks, num_gpu_blocks, window_size):
yield WindowBlockManager(num_cpu_blocks, num_gpu_blocks, window_size)
- def test_alloc(self, block_mgr, block_size, num_gpu_blocks):
- sess = SchedulerSession(0, block_size)
+ def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks):
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
# test alloc
token_ids = torch.tensor([1])
@@ -243,8 +262,8 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks):
msg = sess.add_sequence(token_ids)
assert not block_mgr.can_allocate(msg)
- def test_win_alloc(self, block_mgr, block_size, num_gpu_blocks, window_size):
- sess = SchedulerSession(0, block_size)
+ def test_win_alloc(self, block_mgr, seq_manager, num_gpu_blocks, window_size):
+ sess = SchedulerSession(0, seq_manager)
# 2 win block
token_ids = torch.tensor([1] * window_size)
diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py
index 06829f4c75..7d20c96dab 100644
--- a/tests/pytorch/paging/test_block_trie.py
+++ b/tests/pytorch/paging/test_block_trie.py
@@ -2,7 +2,7 @@
import pytest
from lmdeploy.pytorch.config import CacheConfig
-from lmdeploy.pytorch.messages import SchedulerSession
+from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta
from lmdeploy.pytorch.paging.block_manager import build_block_manager
from lmdeploy.pytorch.paging.block_trie import BlockTrie
@@ -37,9 +37,17 @@ def block_mgr(self, cache_config):
def block_trie(self, cache_config, block_mgr):
yield BlockTrie(cache_config, block_mgr)
- def test_allocate(self, block_trie, block_mgr, block_size):
+ @pytest.fixture
+ def seq_manager(self, block_size):
+ from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
+ strategy = ARSequenceStrategy()
+ seq_meta = SequenceMeta(block_size, strategy=strategy)
+ yield SequenceManager(seq_meta)
+
+ def test_allocate(self, block_trie, block_mgr, seq_manager):
allocator = block_trie.allocator
- sess = SchedulerSession(0, block_size)
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
token_ids = ([1] * block_size + [2] * block_size)
token_ids += [3] * (block_size // 2)
seq = sess.add_sequence(token_ids)
@@ -75,9 +83,10 @@ def test_allocate(self, block_trie, block_mgr, block_size):
assert node in block_trie.leaves
assert len(block_trie.leaves) == 1
- def test_match(self, block_trie, block_mgr, block_size):
+ def test_match(self, block_trie, block_mgr, seq_manager):
allocator = block_trie.allocator
- sess = SchedulerSession(0, block_size)
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
# initialize cache
token_ids = ([1] * block_size + [2] * block_size)
@@ -112,9 +121,10 @@ def test_match(self, block_trie, block_mgr, block_size):
ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks())
assert np.array_equal(ref_cnt, [4, 3])
- def test_evict(self, block_trie, block_size, num_gpu_blocks):
+ def test_evict(self, block_trie, seq_manager, num_gpu_blocks):
block_mgr = block_trie.block_manager
- sess = SchedulerSession(0, block_size)
+ sess = SchedulerSession(0, seq_manager)
+ block_size = sess.seq_meta.block_size
token_ids = ([1] * block_size * (num_gpu_blocks - 1))
token_ids += [2] * (block_size // 2)
seq = sess.add_sequence(token_ids)
diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py
index f14ab8249e..a0acf5f054 100644
--- a/tests/pytorch/paging/test_scheduler.py
+++ b/tests/pytorch/paging/test_scheduler.py
@@ -2,7 +2,7 @@
import torch
from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig
-from lmdeploy.pytorch.messages import MessageStatus
+from lmdeploy.pytorch.messages import MessageStatus, SequenceMeta
from lmdeploy.pytorch.paging.scheduler import Scheduler
@@ -32,8 +32,14 @@ def scheduler_config(self):
yield SchedulerConfig(max_batches=4, max_session_len=128, max_request_output_len=64, eviction_type='recompute')
@pytest.fixture
- def scheduler(self, cache_config, scheduler_config):
- yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config)
+ def seq_meta(self, block_size):
+ from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy
+ strategy = ARSequenceStrategy()
+ yield SequenceMeta(block_size, strategy=strategy)
+
+ @pytest.fixture
+ def scheduler(self, cache_config, scheduler_config, seq_meta):
+ yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta)
def test_schedule_base(self, scheduler, block_size, num_gpu_blocks):
block_manager = scheduler.block_manager