diff --git a/README.md b/README.md index 8bd7fcbc15..5138e0a455 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
  • + SDAR (1.7B-30B)
  • gpt-oss (20B, 120B)
diff --git a/README_ja.md b/README_ja.md index 3537e84935..75d05390ad 100644 --- a/README_ja.md +++ b/README_ja.md @@ -137,6 +137,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
  • + SDAR (1.7B-30B)
diff --git a/README_zh-CN.md b/README_zh-CN.md index 7d7635dc70..179ffa466d 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -151,6 +151,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
  • + SDAR (1.7B-30B)
  • gpt-oss (20B, 120B)
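A minimal sketch of how the new dLLM options on `PytorchEngineConfig` (added in lmdeploy/messages.py below) might be used through the high-level `pipeline` API; the checkpoint path and the concrete option values here are assumptions for illustration, only the option names come from this diff:

```python
# Illustrative sketch, not part of this patch: pass the dLLM options that this diff
# adds to PytorchEngineConfig when serving a block-diffusion (SDAR) checkpoint.
from lmdeploy import pipeline, PytorchEngineConfig

backend_config = PytorchEngineConfig(
    dllm_block_length=4,                               # block size of the block-diffusion model (assumed value)
    dllm_unmasking_strategy='low_confidence_dynamic',  # or 'low_confidence_static' / 'sequential'
    dllm_denoising_steps=None,                         # None keeps the engine default
    dllm_confidence_threshold=0.85,                    # unmasking threshold for the dynamic strategy
)

# 'path/to/SDAR-checkpoint' is a hypothetical path, not a real model id.
pipe = pipeline('path/to/SDAR-checkpoint', backend_config=backend_config)
print(pipe(['Hello, SDAR!']))
```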
  • diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index e2a4f724c6..6e4243cca6 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -307,6 +307,10 @@ def parse_args(): # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.eager_mode(pt_group) + ArgumentHelper.dllm_block_length(pt_group) + ArgumentHelper.dllm_unmasking_strategy(pt_group) + ArgumentHelper.dllm_denoising_steps(pt_group) + ArgumentHelper.dllm_confidence_threshold(pt_group) tp_act = ArgumentHelper.tp(pt_group) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) @@ -363,6 +367,10 @@ def main(): quant_policy=args.quant_policy, dtype=args.dtype, distributed_executor_backend=args.distributed_executor_backend, + dllm_block_length=args.dllm_block_length, + dllm_unmasking_strategy=args.dllm_unmasking_strategy, + dllm_denoising_steps=args.dllm_denoising_steps, + dllm_confidence_threshold=args.dllm_confidence_threshold, ) if args.use_uvloop: diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index d4ffba4787..aa28854d8a 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -120,6 +120,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | +| SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - | ```{note} * [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index d3d6946020..8e9e3fef20 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -120,6 +120,7 @@ | Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | +| SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - | ```{note} * [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16 diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index 916b91e527..d71198791f 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -55,6 +55,7 @@ def add_parser_chat(): ArgumentHelper.adapters(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) + ArgumentHelper.dllm_block_length(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 7fac263001..6a9e9f2b13 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -92,6 +92,10 @@ def add_parser_api_server(): ArgumentHelper.eager_mode(pt_group) ArgumentHelper.disable_vision_encoder(pt_group) ArgumentHelper.logprobs_mode(pt_group) + ArgumentHelper.dllm_block_length(pt_group) + ArgumentHelper.dllm_unmasking_strategy(pt_group) + ArgumentHelper.dllm_denoising_steps(pt_group) + ArgumentHelper.dllm_confidence_threshold(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) @@ -219,6 +223,10 @@ def api_server(args): hf_overrides=args.hf_overrides, disable_vision_encoder=args.disable_vision_encoder, logprobs_mode=args.logprobs_mode, + dllm_block_length=args.dllm_block_length, + dllm_unmasking_strategy=args.dllm_unmasking_strategy, + 
dllm_denoising_steps=args.dllm_denoising_steps, + dllm_confidence_threshold=args.dllm_confidence_threshold, ) else: from lmdeploy.messages import TurbomindEngineConfig diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 4f852b076d..bfd94182d0 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -624,6 +624,36 @@ def logprobs_mode(parser): choices=[None, 'raw_logits', 'raw_logprobs'], help='The mode of logprobs.') + @staticmethod + def dllm_block_length(parser): + """dllm_block_length for dllm.""" + return parser.add_argument('--dllm-block-length', type=int, default=None, help='Block length for dllm') + + @staticmethod + def dllm_unmasking_strategy(parser): + """Dllm unmasking strategy.""" + return parser.add_argument('--dllm-unmasking-strategy', + type=str, + default='low_confidence_dynamic', + choices=['low_confidence_dynamic', 'low_confidence_static', 'sequential'], + help='The unmasking strategy for dllm.') + + @staticmethod + def dllm_denoising_steps(parser): + """Dllm denoising steps.""" + return parser.add_argument('--dllm-denoising-steps', + type=int, + default=None, + help='The number of denoising steps for dllm.') + + @staticmethod + def dllm_confidence_threshold(parser): + """Dllm confidence threshold.""" + return parser.add_argument('--dllm-confidence-threshold', + type=float, + default=0.85, + help='The confidence threshold for dllm.') + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py class FlexibleArgumentParser(argparse.ArgumentParser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 808c47737c..69ac361abc 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -335,6 +335,12 @@ class PytorchEngineConfig: disable_vision_encoder (bool): Whether to disable loading vision encoder. Default to False. logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs'] + dllm_block_length (int): Block size of block diffusion model. + dllm_unmasking_strategy (str): Dllm unmasking strategy, options: + ['low_confidence_dynamic', 'low_confidence_static', 'sequential']. + dllm_denoising_steps (int): Dllm denoising steps. + dllm_confidence_threshold (float): dllm unmasking threshold for + dynamic unmasking. """ dtype: str = 'auto' tp: int = 1 @@ -370,6 +376,12 @@ class PytorchEngineConfig: disable_vision_encoder: bool = False logprobs_mode: str = None + # dllm + dllm_block_length: int = None + dllm_unmasking_strategy: str = 'low_confidence_dynamic' + dllm_denoising_steps: int = None + dllm_confidence_threshold: float = 0.85 + role: EngineRole = EngineRole.Hybrid migration_backend: MigrationBackend = MigrationBackend.DLSlime diff --git a/lmdeploy/metrics/stats.py b/lmdeploy/metrics/stats.py index b21eeaca90..6654b3d9d3 100644 --- a/lmdeploy/metrics/stats.py +++ b/lmdeploy/metrics/stats.py @@ -198,7 +198,10 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState): outputs (EngineOutput): The output from the engine containing information about the current iteration. req_state (RequestState): The state of the request, including timestamps and token counts. 
""" - self.new_generation_tokens = outputs.num_token - req_state.generation_tokens + new_generation_tokens = outputs.num_token - req_state.generation_tokens + if new_generation_tokens == 0: + return + self.new_generation_tokens = new_generation_tokens if req_state.first_token_time == 0: # It means the first token is generated in this iteration req_state.first_token_time = outputs.req_metrics.token_timestamp diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 20c48c880a..2b9b3743f1 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -737,7 +737,7 @@ class HFChatTemplate(BaseChatTemplate): def __init__(self, model_path: str = '', **kwargs): try: - from transformers import AutoTokenizer, PretrainedConfig + from transformers import AutoConfig, AutoTokenizer, PretrainedConfig self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) self.system_start, self.system_end = self._role_instruction('system') self.user_start, self.user_end = self._role_instruction('user') @@ -747,7 +747,10 @@ def __init__(self, model_path: str = '', **kwargs): self.stop_words.append(self.tokenizer.eos_token) if hasattr(self.tokenizer, 'eot_token') and self.tokenizer.eot_token is not None: self.stop_words.append(self.tokenizer.eot_token) - cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) + try: + cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + except Exception as e: # noqa + cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) self.is_gpt_oss = getattr(cfg, 'architectures', [''])[0] == 'GptOssForCausalLM' if self.is_gpt_oss: self.stop_words.append('<|call|>') diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 0792c0ccca..0c842e8871 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -93,6 +93,7 @@ def build( causal: bool = True, use_flash_mla: bool = False, learnable_sink: bool = False, + block_sparse_size: int = 1, **kwargs, ) -> AttentionImpl[T]: """build.""" diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 2a9fcc5201..b241c384b2 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import functools from dataclasses import dataclass from typing import Literal @@ -20,8 +21,8 @@ assert torch.ops.flash_attn_3 is not None use_fa3 = True except Exception: - logger.warning('For higher performance, please install FlashAttention-3 ' - 'https://github.com/Dao-AILab/flash-attention') + logger.debug('For higher performance, please install FlashAttention-3 ' + 'https://github.com/Dao-AILab/flash-attention') @dataclass @@ -62,6 +63,7 @@ def __init__( sliding_window: int = None, logit_softcapping: float = None, causal: bool = True, + block_sparse_size: int = 1, **kwargs, ): super().__init__( @@ -91,6 +93,7 @@ def __init__( world_size, rank = get_tp_world_rank() self.alibi_head_offset = self.num_heads * rank self.alibi_num_heads = self.num_heads * world_size + self.block_sparse_size = block_sparse_size def forward( self, @@ -116,7 +119,7 @@ def forward( kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy if attn_metadata.is_decoding: - max_q_seqlen = 1 + max_q_seqlen = self.block_sparse_size else: max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) fill_max_q_seqlen = max_q_seqlen @@ -213,11 +216,21 @@ def forward( logit_softcapping=self.logit_softcapping, sinks=learnable_sink, causal=self.causal, + block_sparse_size=self.block_sparse_size, ) return attn_output +@functools.lru_cache +def use_fa3_warning(): + if use_fa3: + return True + logger.warning('For higher performance, please install FlashAttention-3 ' + 'https://github.com/Dao-AILab/flash-attention') + return False + + class FlashMLAImpl(TritonAttentionImpl): def __init__( @@ -252,6 +265,7 @@ def __init__( from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd self.flash_mla_fwd = flash_mla_fwd assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1' + use_fa3_warning() def forward( self, @@ -512,6 +526,14 @@ def forward( return attn_output +@functools.lru_cache +def _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int): + enable = not alibi and not learnable_sink and block_sparse_size == 1 + if enable and not use_fa3_warning(): + enable = False + return enable + + class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]): """Triton attention builder.""" @@ -528,10 +550,13 @@ def build( causal: bool = True, use_flash_mla: bool = False, learnable_sink: bool = False, + block_sparse_size: int = 1, **kwargs, ) -> TritonAttentionImpl: """build.""" + enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size) if use_flash_mla is True: + logger.debug('Build FlashMLAImpl Attention') return FlashMLAImpl(num_heads, head_size, scale=scale, @@ -542,7 +567,8 @@ def build( logical_softcapping=logical_softcapping, causal=causal, **kwargs) - elif use_fa3 and not alibi and not learnable_sink: + elif enable_fa3: + logger.debug('Build FA3Impl Attention') return FA3Impl(num_heads, head_size, scale=scale, @@ -554,6 +580,7 @@ def build( causal=causal, **kwargs) else: + logger.debug('Build TritonAttentionImpl Attention') return TritonAttentionImpl(num_heads, head_size, scale=scale, @@ -563,4 +590,5 @@ def build( sliding_window=sliding_window, logical_softcapping=logical_softcapping, causal=causal, + block_sparse_size=block_sparse_size, **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index ccc413be75..deb6c66bfd 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -9,9 +9,11 @@ from lmdeploy.pytorch.config import 
BackendConfig, CacheConfig, ModelConfig from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta +from lmdeploy.pytorch.strategies.base import StrategyFactoryBase from lmdeploy.utils import get_logger from ..graph_runner import GraphRunner +from .attention import TritonAttentionMetadata logger = get_logger('lmdeploy') @@ -146,6 +148,11 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() self.has_try_compile_model: bool = False + # strategy factory + build_ctx = model.ctx_mgr.build_ctx + strategy_factory: StrategyFactoryBase = build_ctx.strategy_factory + self.cudagraph_strategy = strategy_factory.build_cudagraph_strategy() + def check_enable_graph(self): """Check enable graph.""" if self.backend_config.eager_mode: @@ -173,18 +180,24 @@ def _get_capture_tokens(self, batch_size: int): assert False, f'Unsupported batch_size={batch_size}' def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, - attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs): + attn_metadata: TritonAttentionMetadata, inputs_embeds: torch.Tensor, **kwargs): """Get graph key.""" context = self.ctx_mgr.current_context() is_decoding = context.is_decoding - num_tokens = input_ids.numel() + batch_size = attn_metadata.q_seqlens.size(0) meta = self.get_meta() enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch if meta.padding_batch_size is None: - new_num_tokens = self._get_capture_tokens(num_tokens) + batch_size = self._get_capture_tokens(batch_size) else: - new_num_tokens = self._get_capture_tokens(meta.padding_batch_size) - return (new_num_tokens, is_decoding, enable_microbatch) + batch_size = self._get_capture_tokens(meta.padding_batch_size) + return (batch_size, is_decoding, enable_microbatch) + + def _get_max_tokens(self, graph_key: tuple): + max_batches = graph_key[0] + is_decoding = graph_key[1] + assert is_decoding + return self.cudagraph_strategy.get_max_tokens(max_batches) def __call__(self, **kwargs): """call.""" @@ -198,10 +211,10 @@ def __call__(self, **kwargs): return self.model(**kwargs) graph_key = self.get_graph_key(**kwargs) - max_tokens = graph_key[0] + max_batches = graph_key[0] is_decoding = graph_key[1] if graph_key not in self._runner_map: - max_batches = max_tokens if is_decoding else self.max_batches + max_tokens = self._get_max_tokens(graph_key) runner = CUDASingleGraphRunner(self.model, max_batches=max_batches, max_tokens=max_tokens, diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py index defd43303b..20102b1deb 100644 --- a/lmdeploy/pytorch/check_env/transformers.py +++ b/lmdeploy/pytorch/check_env/transformers.py @@ -4,7 +4,7 @@ from .base import BaseChecker MIN_TRANSFORMERS_VERSION = '4.33.0' -MAX_TRANSFORMERS_VERSION = '4.53.3' +MAX_TRANSFORMERS_VERSION = '4.56.1' class TransformersChecker(BaseChecker): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 05c716c6a9..ac3459e045 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import enum from dataclasses import dataclass from typing import Any, Dict, List, Literal @@ -27,7 +28,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): config.dtype = torch.float16 return config - torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + torch_dtype = getattr(config.hf_config, 'dtype', None) + if torch_dtype is None: + torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + # deal with case when torch_dtype is not string but torch.dtype if isinstance(torch_dtype, torch.dtype): torch_dtype = str(torch_dtype).split('.')[1] @@ -200,6 +204,9 @@ class ModelConfig: cogvlm_style: bool = False custom_module_map: Dict[str, setattr] = None use_flash_mla: bool = False + model_paradigm: str = 'ar' + dllm_mask_token: int = 0 + dllm_block_length: int = None def get_head_size(self): """Get head size.""" @@ -285,6 +292,38 @@ def from_hf_config(cls, return model_config +class UnmaskingStrategy(enum.Enum): + """Unmasking Strategy.""" + + # unmasking from left to right + SEQUENTIAL = enum.auto() + # unmasking with confidence threshold + LOW_CONFIDENCE_DYNAMIC = enum.auto() + # unmasking with topk in a block + LOW_CONFIDENCE_STATIC = enum.auto() + + @classmethod + def from_str(cls, strategy: str): + """From string.""" + strategy = strategy.lower() + if strategy == 'sequential': + return cls.SEQUENTIAL + elif strategy == 'low_confidence_dynamic': + return cls.LOW_CONFIDENCE_DYNAMIC + elif strategy == 'low_confidence_static': + return cls.LOW_CONFIDENCE_STATIC + else: + raise ValueError(f'Unknown unmasking strategy: {strategy}') + + +@dataclass +class DLLMConfig: + block_length: int = 1 + unmasking_strategy: UnmaskingStrategy = UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC + denoising_steps: int = None + confidence_threshold: float = 0.85 + + @dataclass class MiscConfig: prefill_interval: int = 16 @@ -294,15 +333,22 @@ class MiscConfig: hf_overrides: Dict[str, Any] = None disable_vision_encoder: bool = False logprobs_mode: str = None + dllm_config: DLLMConfig = None @classmethod def from_engine_config(cls, engine_config: PytorchEngineConfig): """From engine config.""" + dllm_unmasking_strategy = UnmaskingStrategy.from_str(engine_config.dllm_unmasking_strategy) + dllm_config = DLLMConfig(block_length=engine_config.dllm_block_length, + unmasking_strategy=dllm_unmasking_strategy, + denoising_steps=engine_config.dllm_denoising_steps, + confidence_threshold=engine_config.dllm_confidence_threshold) misc_config = cls(custom_module_map=engine_config.custom_module_map, empty_init=engine_config.empty_init, prefill_interval=engine_config.prefill_interval, model_format=engine_config.model_format, hf_overrides=engine_config.hf_overrides, disable_vision_encoder=engine_config.disable_vision_encoder, - logprobs_mode=engine_config.logprobs_mode) + logprobs_mode=engine_config.logprobs_mode, + dllm_config=dllm_config) return misc_config diff --git a/lmdeploy/pytorch/configurations/sdar.py b/lmdeploy/pytorch/configurations/sdar.py new file mode 100644 index 0000000000..edf9cd3cad --- /dev/null +++ b/lmdeploy/pytorch/configurations/sdar.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .default import AutoModelConfigBuilder, DefaultModelConfigBuilder + + +class SDARModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.model_type in ['sdar', 'sdar_moe'] + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + cfg.dllm_mask_token = 151669 + cfg.model_paradigm = 'dllm' + return cfg diff --git a/lmdeploy/pytorch/consts.py b/lmdeploy/pytorch/consts.py new file mode 100644 index 0000000000..93dc2f4d71 --- /dev/null +++ b/lmdeploy/pytorch/consts.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# dllm +DLLM_MASKED = 0 +DLLM_UNMASKED = 1 +DLLM_CACHED = 2 diff --git a/lmdeploy/pytorch/disagg/backend/__init__.py b/lmdeploy/pytorch/disagg/backend/__init__.py index 3ab02d3bd6..2309e16c02 100644 --- a/lmdeploy/pytorch/disagg/backend/__init__.py +++ b/lmdeploy/pytorch/disagg/backend/__init__.py @@ -7,7 +7,7 @@ logger.debug('Registering DLSlime Backend') from .dlslime import DLSlimeBackend except ImportError: - logger.warning('Disable DLSlime Backend') + logger.debug('Disable DLSlime Backend') try: logger.debug('Registering Mooncake Backend') diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index f94c6c9a46..2179b5a99f 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -20,13 +20,13 @@ from ..adapter.adapter import AdapterManager from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig -from ..messages import MessageStatus, SchedulerSequence +from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler +from ..strategies import build_strategy_factory from .base import EngineBase from .engine_checker import EngineChecker from .executor import build_executor -from .logits_process import SamplingInputs from .model_agent import BatchedOutputs from .request import Request, RequestManager, RequestType, Response @@ -81,6 +81,15 @@ def _update_engine_config(engine_config: PytorchEngineConfig): if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size(engine_config.device_type) + if engine_config.dllm_block_length is not None: + max_prefill_token_num = engine_config.max_prefill_token_num + max_batch_size = engine_config.max_batch_size + if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num: + engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length + logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} ' + f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' + f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') + if engine_config.dp != 1: if engine_config.tp == 1 and engine_config.ep == 1: engine_config.dp = 1 @@ -139,6 +148,13 @@ def _build_misc_config(engine_config: PytorchEngineConfig): return misc_config +def _build_seq_meta(cache_config: CacheConfig, strategy: Any): + from lmdeploy.pytorch.messages import SequenceMeta + + seq_meta = SequenceMeta(cache_config.block_size, strategy=strategy) + return seq_meta + + class CounterEvent: def __init__(self): @@ -366,10 +382,19 @@ def __init__(self, dtype=engine_config.dtype) self.executor.init() + # strategies + self.strategy_factory = 
build_strategy_factory(self.model_config, self.executor.misc_config) + self.sampling_strategy = self.strategy_factory.build_sampling_strategy() + self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy() + self.engine_strategy = self.strategy_factory.build_engine_strategy(cache_config=cache_config, + scheduler_config=scheduler_config) + self.seq_strategy = self.strategy_factory.build_sequence_strategy() + self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.scheduler = Scheduler(scheduler_config, cache_config) + self.seq_meta = _build_seq_meta(cache_config, strategy=self.seq_strategy) + self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta) # engine args self.model_path = model_path @@ -378,6 +403,7 @@ def __init__(self, self.cache_config = cache_config self.backend_config = backend_config self.dist_config = dist_config + self.misc_config = self.executor.misc_config self.max_session_len = self._get_max_session_len() self.engine_config.num_cpu_blocks = self.cache_config.num_cpu_blocks self.engine_config.num_gpu_blocks = self.cache_config.num_gpu_blocks @@ -575,7 +601,7 @@ def __update_max_new_tokens(msg): max_session_len = self.max_session_len sampling_param = msg.sampling_param max_new_tokens = sampling_param.max_new_tokens - num_all_tokens = msg.num_all_tokens() + num_all_tokens = msg.num_valid_ids if max_new_tokens + num_all_tokens > max_session_len: logger.warning( f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. ' @@ -592,14 +618,12 @@ def __update_max_new_tokens(msg): sess = scheduler.sessions[session_id] # TODO: support 1 session n sequence sampling_param = req.data['sampling_param'] - return_logits = sampling_param.out_logits if len(sess.sequences) == 0: migration_request = req.data.get('migration_request') assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence(req.data['token_ids'], sampling_param=sampling_param, adapter_name=req.data['adapter_name'], - return_logits=return_logits, multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings', ), migration_request=migration_request, @@ -617,11 +641,9 @@ def __update_max_new_tokens(msg): req.data['token_ids'], multimodals=req.data.get('input_multimodals'), embeddings=req.data.get('input_embeddings'), - append_tokens=True, + mode=UpdateTokenMode.INPUTS, ) - msg.num_new_tokens = 0 msg.sampling_param = sampling_param - msg.return_logits = return_logits msg.status = MessageStatus.WAITING __update_max_new_tokens(msg) @@ -659,10 +681,11 @@ def __get_vlm_embeddings(): ] input_embedding_indexing = torch.zeros((batch_size, max_q_seq_length), dtype=torch.bool) for msg_id, msg in enumerate(messages): + num_history_ids = msg.num_history_ids for emb in msg.input_embeddings: # make slice index relative to embeddings - emb_start = emb.start - msg.history_len - emb_end = emb.end - msg.history_len + emb_start = emb.start - num_history_ids + emb_end = emb.end - num_history_ids input_embedding_indexing[msg_id][emb_start:emb_end] = True return (input_embeddings, input_embedding_indexing, input_embedding_ranges) @@ -716,7 +739,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): """ batch_size = len(messages) # history lengths - history_lengths = torch.tensor([msg.history_len for msg in messages]) + history_lengths = torch.tensor([msg.num_history_ids for msg in 
messages]) # input ids token_ids = [msg.token_ids for msg in messages] @@ -729,8 +752,8 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): seq_length = torch.tensor(seq_length, dtype=torch.long) max_q_seqlen = seq_length.max().item() else: - seq_length = torch.ones(batch_size, dtype=torch.long) - max_q_seqlen = 1 + max_q_seqlen = len(token_ids[0]) + seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long) kv_seqlens = seq_length + history_lengths max_kv_seqlen = kv_seqlens.max().item() sum_kv_seqlen = kv_seqlens.sum().item() @@ -780,23 +803,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): return model_inputs - def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped: torch.Tensor, - model_metas: List[Dict[str, Any]]): - """Update scheduler.""" - if model_metas is None: - model_metas = [None] * len(running) - next_token_ids = next_token_ids.numpy() - for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): - if msg.status != MessageStatus.LOCKED: - continue - update_token = token - - # fill token - msg.update_token_ids(update_token, model_meta=model_meta) - msg.num_new_tokens += 1 - if stop: - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED - def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]): """Update scheduler.""" @@ -808,40 +814,40 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, update_token = token # fill token - msg.update_token_ids(update_token, model_meta=model_meta) - msg.num_new_tokens += 1 + msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) if stop: update_token = _EMPTY_TOKEN - msg.update_token_ids(update_token, model_meta=model_meta) + msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) msg.status = MessageStatus.STOPPED def _make_infer_outputs( self, batched_outputs: BatchedOutputs, running: SeqList, + is_decoding: bool, ): """Make infer output.""" new_token_timestamp = batched_outputs.new_token_timestamp - next_token_ids = batched_outputs.next_token_ids logits = batched_outputs.logits - stopped = batched_outputs.stopped - model_metas = batched_outputs.model_metas logprobs = batched_outputs.logprobs seq_length = [seq.num_token_ids for seq in running] is_run = [seq.status == MessageStatus.LOCKED for seq in running] - stopped = stopped.tolist() - self.update_running(running, next_token_ids, stopped, model_metas) + self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding) # generate output outputs: Dict[int, InferOutput] = dict() for idx, msg in enumerate(running): if not is_run[idx]: continue - token_ids = msg.all_ids[-msg.num_new_tokens:] + token_ids = msg.generated_ids finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED if not finish and len(token_ids) == 0: continue + resp_data = msg.resp.data + if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids): + # no new tokens + continue session_id = msg.session_id if msg.resp_cache: cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist() @@ -870,51 +876,6 @@ def _make_infer_outputs( def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): """Make forward inputs.""" - prefill_interval = self.scheduler_config.prefill_interval - - def 
__gather_all_ids(seqs: SeqList, sampling_inputs: SamplingInputs): - """Gather history.""" - if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): - return None - batch = len(seqs) - max_len = max(seq.num_all_ids for seq in seqs) - pad_id = self.model_config.bos_token_id - pad_id = 0 if pad_id is None else pad_id - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_all_ids - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.all_ids) - output[idx, -h_len:] = h_ids - return output - - def __gather_guided_input_ids(seqs: SeqList, sampling_inputs: SamplingInputs): - """Gather input ids for guided decode.""" - if not any(sampling_inputs.response_formats or ()): - return None - batch = len(seqs) - max_len = max(seq.num_new_tokens for seq in seqs) - pad_id = self.model_config.bos_token_id - pad_id = 0 if pad_id is None else pad_id - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_new_tokens - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.all_ids[-seq.num_new_tokens:]) - output[idx, -h_len:] = h_ids - return output - - def __get_num_appendable_ids(seqs: SeqList): - """Get num appendable ids.""" - ret = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] - return torch.tensor(ret) - - def __get_num_ignore_eos(seqs: SeqList): - """Get num ignore eos.""" - ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] - return torch.tensor(ret) def __need_logits(seqs: SeqList): """Need logits.""" @@ -923,7 +884,8 @@ def __need_logits(seqs: SeqList): scheduler = self.scheduler logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) + prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) if enable_empty and len(scheduler_output.running) == 0: return None @@ -931,9 +893,10 @@ def __need_logits(seqs: SeqList): # schedule decoding if no valid prefill reqs. 
if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill: prefill = False - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) + prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) - num_loops = 1 if prefill else prefill_interval + num_loops = self.engine_strategy.get_num_loops(not prefill) running = scheduler_output.running swap_in_map = scheduler_output.swap_in_map swap_out_map = scheduler_output.swap_out_map @@ -943,12 +906,10 @@ def __need_logits(seqs: SeqList): # create inputs inputs = self.create_model_inputs(running, prefill) - sampling_inputs = SamplingInputs.from_sampling_params(running) - all_ids = __gather_all_ids(running, sampling_inputs) - guided_input_ids = __gather_guided_input_ids(running, sampling_inputs) - num_appendable_ids = __get_num_appendable_ids(running) - num_ignore_eos = __get_num_ignore_eos(running) + sampling_inputs = self.sampling_strategy.make_sampling_inputs(running) return_logits = __need_logits(running) + extra_inputs = self.model_agent_strategy.make_extra_inputs(running) + stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running) sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num return dict( @@ -957,14 +918,12 @@ def __need_logits(seqs: SeqList): swap_in_map=swap_in_map, swap_out_map=swap_out_map, loop_count=num_loops, - all_ids=all_ids, - guided_input_ids=guided_input_ids, sampling_inputs=sampling_inputs, - num_appendable_ids=num_appendable_ids, - num_ignore_eos=num_ignore_eos, + stopping_criteria=stopping_criteria, return_logits=return_logits, is_dummy=False, sync_long_context=sync_long_context, + extra_inputs=extra_inputs, ) async def _await_forward_event(self, forward_event: asyncio.Event): @@ -1132,6 +1091,7 @@ async def _async_loop_main( forward_event.set() num_loops = forward_inputs['loop_count'] + is_decoding = forward_inputs['inputs'].is_decoding running = next_running next_running = None scheduler.lock_running(running) @@ -1145,7 +1105,7 @@ async def _async_loop_main( # send output out = await self.executor.get_output_async() if out is not None: - step_outputs = self._make_infer_outputs(out, running=running) + step_outputs = self._make_infer_outputs(out, running=running, is_decoding=is_decoding) resp_que.put_nowait(step_outputs) # lock forward event diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 70a906df0f..9e50843a80 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -35,7 +35,7 @@ def __init__(self, self.cache_config = cache_config self.backend_config = backend_config self.dist_config = dist_config - self.misc_config = misc_config, + self.misc_config = misc_config self.dp = dist_config.dp self.tp = dist_config.tp self.world_size = dist_config.world_size diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 214f83256e..b30fbb3992 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -109,6 +109,9 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided return scores +SeqList = List[SchedulerSequence] + + @dataclass class SamplingInputs: temperature: torch.Tensor = None @@ -120,142 +123,17 @@ class SamplingInputs: top_k: torch.LongTensor = None top_p: torch.Tensor = None 
min_p: torch.Tensor = None - random_seeds: int = None - random_offsets: int = None + random_seeds: torch.Tensor = None + random_offsets: torch.Tensor = None max_top_k: int = 1 min_top_p: float = 1.0 response_formats: Tuple[str] = () logits_processors: List[List[LogitsProcessor]] = None max_num_logprobs: Optional[int] = None - - @classmethod - def from_sampling_params(cls, seqs: List[SchedulerSequence]): - """From samplingg params.""" - batch_size = len(seqs) - temperature = [None] * batch_size - repetition_penalty = [None] * batch_size - top_k = [None] * batch_size - top_p = [None] * batch_size - min_p = [None] * batch_size - bad_words = [None] * batch_size - stop_words = [None] * batch_size - random_seeds = [torch.seed() & 0xffffffff] * batch_size - random_offsets = [None] * batch_size - response_formats = [None] * batch_size - logits_processors = [None] * batch_size - num_logprobs = [None] * batch_size - - def __gather_params(): - """Gather params.""" - for idx, seq in enumerate(seqs): - param = seq.sampling_param - temperature[idx] = param.temperature - repetition_penalty[idx] = param.repetition_penalty - top_k[idx] = param.top_k - top_p[idx] = param.top_p - min_p[idx] = param.min_p - random_offsets[idx] = seq.random_offsets - response_formats[idx] = param.response_format - if param.random_seed is not None: - random_seeds[idx] = param.random_seed & 0xffffffff - - bw = param.bad_words - sw = param.stop_words - if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens): - bw = bw + sw - bad_words[idx] = bw - stop_words[idx] = sw - logits_processors[idx] = param.logits_processors - num_logprobs[idx] = param.num_logprobs - - def __get_topp(top_p): - """Get topp.""" - min_top_p = min(top_p) - if min_top_p == 1.0: - top_p = None - else: - top_p = torch.tensor(top_p) - return top_p, min_top_p - - def __get_minp(min_p): - """Get minp.""" - max_min_p = max(min_p) - if max_min_p == 0.0: - min_p = None - else: - min_p = torch.Tensor(min_p) - return min_p - - def __get_bad_words(bad_words): - """Get bad words.""" - max_bw_len = max(len(bw) for bw in bad_words) - if max_bw_len == 0: - return None, None - if all(len(bw) == max_bw_len for bw in bad_words): - ret = torch.tensor(bad_words) - mask = torch.ones_like(ret, dtype=bool) - return ret, mask - ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) - for idx, bw in enumerate(bad_words): - bw_len = len(bw) - if bw_len == 0: - continue - bw = ret.new_tensor(bw) - ret[idx, :bw_len] = bw - - mask = ret >= 0 - ret = ret.where(mask, 0) - return ret, mask - - __gather_params() - - if all(rp == 1.0 for rp in repetition_penalty): - repetition_penalty = None - else: - repetition_penalty = torch.tensor(repetition_penalty) - - temperature = torch.tensor(temperature) - - bad_words, bad_mask = __get_bad_words(bad_words) - stop_words, stop_mask = __get_bad_words(stop_words) - - max_top_k = max(top_k) - if min(top_k) <= 0: - max_top_k = 0 - if max_top_k == 1: - top_k = None - top_p, min_top_p = None, 1.0 - min_p = None - random_seeds = None - random_offsets = None - else: - top_k = torch.tensor(top_k) - top_p, min_top_p = __get_topp(top_p) - min_p = __get_minp(min_p) - random_seeds = torch.tensor(random_seeds) - random_offsets = torch.tensor(random_offsets) - - max_num_logprobs = max(num_logprobs) - - sampling_input = cls( - temperature=temperature, - bad_words=bad_words, - bad_mask=bad_mask, - stop_words=stop_words, - stop_mask=stop_mask, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - min_p=min_p, - 
random_seeds=random_seeds, - random_offsets=random_offsets, - response_formats=tuple(response_formats), - max_top_k=max_top_k, - min_top_p=min_top_p, - logits_processors=logits_processors, - max_num_logprobs=max_num_logprobs, - ) - return sampling_input + all_ids: Optional[torch.Tensor] = None + guided_input_ids: Optional[torch.Tensor] = None + num_ignore_eos: torch.Tensor = None + batch_size: int = 0 def to_device(self, device: str, non_blocking: bool = False): """To device.""" @@ -284,12 +162,10 @@ class FusedLogitsProcessor: def __init__(self, sampling_inputs: SamplingInputs, - ignore_eos: torch.Tensor, tokenizer: Optional[Tokenizer] = None, sampling_vocab_size: Optional[int] = None, logprobs_mode: Optional[str] = None): self.sampling_inputs: SamplingInputs = sampling_inputs - self.ignore_eos = ignore_eos self.tokenizer = tokenizer self.sampling_vocab_size = sampling_vocab_size self.logprobs_mode = logprobs_mode @@ -300,12 +176,9 @@ async def _wait_stream_once(self): if not stream.query(): await asyncio.sleep(0) - async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.LongTensor, - scores: torch.FloatTensor) -> torch.FloatTensor: + async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: r""" Args: - all_ids (torch.LongTensor): All the token ids. - guided_input_ids (torch.LongTensor): Guided prompt ids. scores (torch.FloatTensor): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using @@ -331,6 +204,8 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long logprobs = None sampling_inputs = self.sampling_inputs + all_ids = sampling_inputs.all_ids + guided_input_ids = sampling_inputs.guided_input_ids custom_logits_processors = self.sampling_inputs.logits_processors if any(custom_logits_processors): @@ -352,8 +227,9 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long stop_words = sampling_inputs.stop_words if stop_words is not None: + ignore_eos = sampling_inputs.num_ignore_eos > 0 stop_mask = sampling_inputs.stop_mask - stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False) + stop_mask = torch.where(ignore_eos[:, None], stop_mask, False) scores = _process_bad_words_(scores, stop_words, stop_mask) if guided_input_ids is not None: diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 3457e5efe9..7b332b6f0c 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -25,6 +25,8 @@ from ..distributed import DistContext, get_dist_manager from ..model_inputs import ModelInputs, step_ctx_manager from ..models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map +from ..strategies import build_strategy_factory +from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine @@ -63,10 +65,12 @@ def to_tensor(self): class BatchedOutputs: next_token_ids: torch.Tensor stopped: torch.Tensor + stop_pos: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None model_metas: List[Dict[str, Any]] = None logprobs: Optional[BatchedLogProbs] = None new_token_timestamp: int = 0 + extra_outputs: Optional[ExtraOutputs] = None def to_cpu(self): """To cpu.""" @@ -199,6 +203,8 @@ def msg_with_rank(rank: int, msg: str): def cache_swapping(cache_engine: CacheEngine, 
swap_in_map: dict, swap_out_map: dict): """Perform cache swapping.""" issued_cache_op = False + swap_in_map = swap_in_map or dict() + swap_out_map = swap_out_map or dict() if len(swap_in_map) > 0: cache_engine.swap_in(swap_in_map) issued_cache_op = True @@ -242,24 +248,11 @@ def model_forward( # InternVL-3.5-Flash will change the seqlen, model_metas during forward model_metas = context.model_metas - seq_length = context.q_seqlens + seq_length = context.q_seqlens[:len(inputs.seq_length)] return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) -@record_function('stopping_criteria') -def _batch_stopping_criteria(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor): - """Batched stopping criteria.""" - num_appendable_ids = num_appendable_ids - 1 - stopped = num_appendable_ids <= 0 - if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) - stopped = stopped | sw_stopped - one_ids = torch.clamp_max(num_appendable_ids, 0) - num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) - return stopped, num_appendable_ids - - def _try_to_cuda(val, non_blocking: bool = False): if val is None: return val @@ -369,6 +362,11 @@ def __init__(self, self.enable_microbatch_decode_batchsize_threshold = \ int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2)) + # strategy + self.strategy_factory = build_strategy_factory(model_config, misc_config) + self.inputs_strategy = self.strategy_factory.build_model_inputs_strategy() + self.agent_strategy = self.strategy_factory.build_model_agent_strategy() + @contextmanager def all_context(self): device_mgr = get_device_manager() @@ -399,33 +397,42 @@ def warmup(self): num_tokens = max_batches # warmup prefill - inputs = ModelInputs.make_dummy(max_batches, - is_decoding=False, - device='cuda', - vocab_size=self.model_config.vocab_size) - self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict()) + inputs = self.inputs_strategy.make_dummy(max_batches, + is_decoding=False, + device='cuda', + vocab_size=self.model_config.vocab_size) + self._forward_impl(inputs) # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() capture_batch_sizes = sorted(capture_batch_sizes, reverse=True) for num_tokens in capture_batch_sizes: - inputs = ModelInputs.make_dummy(num_tokens, - is_decoding=True, - device='cuda', - vocab_size=self.model_config.vocab_size) - self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict()) + inputs = self.inputs_strategy.make_dummy(num_tokens, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size) + self._forward_impl(inputs) + + def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): + """Slice outputs.""" + return self.agent_strategy.slice_outputs(inputs, seq_length) + + def _postprocess_forward_output(self, output: dict, inputs: ModelInputs): + """Post process forward output.""" + hidden_states = output['hidden_states'] + seq_length = output.get('seq_length', inputs.seq_length) + hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] + output['hidden_states'] = hidden_states + return output async def _async_model_forward( self, inputs: ModelInputs, - swap_in_map: Dict, - swap_out_map: Dict, return_logits: bool, sync_long_context: bool, ): """Model forward.""" max_prefill_token_num = self.cache_config.max_prefill_token_num - swap_done = False class _OutputGather: """Output gather.""" @@ -433,7 +440,8 @@ class _OutputGather: def 
__init__(self, max_seq_len): self._max_seq_len = max_seq_len self._start = 0 - self._output = None + self._output: torch.Tensor = None + self._device: torch.device = None def gather(self, output): """gather.""" @@ -448,6 +456,7 @@ def gather(self, output): seq_len = tmp_output.size(-2) if out_logits is None: out_logits = tmp_output.new_empty(1, self._max_seq_len, tmp_output.size(-1), device='cpu') + self._device = tmp_output.device out_logits[:, start:start + seq_len].copy_(tmp_output, non_blocking=True) self._start = start + seq_len self._output = out_logits @@ -457,16 +466,9 @@ def get_output(self): if not return_logits: return self._output[:, -1:] torch.cuda.synchronize() - return self._output + return self._output.to(self._device) - async def __forward(inputs): - """forward.""" - nonlocal swap_done, swap_in_map, swap_out_map - if swap_done: - return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict()) - else: - swap_done = True - return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + __forward = self.async_forward async def __long_context_single_forward(new_inputs, max_seqlen: int): """One large sequence.""" @@ -485,6 +487,8 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): tmp_out['hidden_states'] = output_gather.get_output() return tmp_out + origin_inputs = inputs + # make long context inputs is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding max_seqlen = 0 @@ -507,24 +511,15 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): if not is_long_context: ret = await __forward(inputs) - if not return_logits and not inputs.is_decoding: - # fetch seq_length from the context, since models may change it (e.g. InternVL-3.5-Flash) - seq_length = ret.get('seq_length', None) - assert seq_length is not None, 'seq_length cannot be None.' 
- last_token_loc = seq_length.cumsum(0) - 1 - - ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] else: ret = await __long_context_single_forward(inputs, max_seqlen) - if not return_logits: - last_token_loc = [-1] - ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] - else: - ret['hidden_states'] = ret['hidden_states'].to('cuda') + + if not return_logits: + ret = self._postprocess_forward_output(ret, origin_inputs) # compute dummy loop if dummy_loop > 0: - dummy_inputs = ModelInputs.make_dummy(1, False, 'cuda', vocab_size=self.model_config.vocab_size) + dummy_inputs = self.inputs_strategy.make_dummy(1, False, 'cuda', vocab_size=self.model_config.vocab_size) for _ in range(dummy_loop): await __forward(dummy_inputs) @@ -533,29 +528,18 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): ret['logits'] = logits return ret - async def async_sampling_logits(self, logits: torch.Tensor, all_ids: torch.Tensor, guided_input_ids: torch.Tensor, - sampling_inputs: SamplingInputs, inputs: ModelInputs, ignore_eos: torch.Tensor): + async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs, inputs: ModelInputs): """Sampling logits.""" - def __get_last_logits(): - """Get last logits.""" - seq_length = inputs.seq_length - if len(seq_length) == logits.size(0): - return logits - - last_idx = seq_length.cumsum(-1) - 1 - return logits[last_idx, :] - # record function does not support async function # so we can not decorate it on async_sampling_logits with record_function('sampling_logits'): - split_logits = __get_last_logits() logits_processor = FusedLogitsProcessor(sampling_inputs, - ignore_eos, self.tokenizer, sampling_vocab_size=self.sampling_vocab_size, logprobs_mode=self.misc_config.logprobs_mode) - logits, raw_logprobs = await logits_processor(all_ids, guided_input_ids, split_logits) + origin_logits = logits + logits, raw_logprobs = await logits_processor(origin_logits) next_token_ids = logits_processor.sampling(logits) logprobs = logits_processor.compute_logprobs(raw_logprobs, next_token_ids) if logprobs is not None: @@ -572,17 +556,18 @@ def _push_output(self, output: BatchedOutputs): event.record() self._out_que.put_nowait((output, event)) - def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None): + @contextmanager + def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None, enable: bool = True): + if not enable: + yield + return + if dist_ctx is None: dist_ctx = get_dist_manager().current_context() - if self.cache_config.role == EngineRole.Decode: - next_token_ids = next_token_ids.cpu() - tp_cpu_group = dist_ctx.tp_cpu_group - dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group) - else: - tp_gpu_group = dist_ctx.tp_gpu_group - dist.broadcast(next_token_ids, src=0, group=tp_gpu_group) - return next_token_ids + tp_gpu_group = dist_ctx.tp_gpu_group + handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) + yield + handle.wait() async def _async_step_background( self, @@ -590,38 +575,26 @@ async def _async_step_background( loop_count: int, swap_in_map: Dict = None, swap_out_map: Dict = None, - all_ids: torch.Tensor = None, - guided_input_ids: torch.Tensor = None, sampling_inputs: SamplingInputs = None, - num_appendable_ids: torch.LongTensor = None, - num_ignore_eos: torch.LongTensor = None, + stopping_criteria: StoppingCriteria = None, return_logits: bool = False, is_dummy: bool = False, sync_long_context: bool = 
False, + extra_inputs: ExtraInputs = None, ): """Asyc forward task.""" - if swap_in_map is None: - swap_in_map = dict() - - if swap_out_map is None: - swap_out_map = dict() - dist_ctx = get_dist_manager().current_context() @record_function('update_inputs_for_next_step') - def __update_inputs(next_token_ids, model_metas): + def __update_inputs(next_token_ids, model_metas, extra_inputs): """Update inputs.""" - nonlocal all_ids, guided_input_ids, swap_in_map, swap_out_map - swap_in_map = dict() - swap_out_map = dict() - inputs.model_metas = model_metas - inputs.update(next_token_ids) - if all_ids is not None: - all_ids = torch.cat([all_ids, next_token_ids[:, None].to(all_ids.device)], 1) - if guided_input_ids is not None: - guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None].to(guided_input_ids.device)], 1) - if sampling_inputs.random_offsets is not None: - sampling_inputs.random_offsets += 1 + return self.agent_strategy.update_inputs_for_next_step( + inputs, + sampling_inputs, + next_token_ids=next_token_ids, + model_metas=model_metas, + extra_inputs=extra_inputs, + ) @asynccontextmanager async def __prepare_dp(): @@ -705,61 +678,78 @@ async def __prepare_dp(): logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') return + cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) for idx in range(loop_count): # inference logger.debug(f' rank[{rank}]: model forward [{idx}].') output = await self._async_model_forward( inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map, return_logits=return_logits, sync_long_context=sync_long_context, ) logits = output['logits'] logits = logits[0] # [bs, seq, prob] -> [seq, prob] + seq_length = inputs.seq_length + seq_length = output.get('seq_length', inputs.seq_length) + last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob] + extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, seq_length) # output empty for dummy inputs if is_dummy: continue + need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1) + # sampling and stopping if need_output: logger.debug(f' rank[{rank}]: Sampling [{idx}].') # sampling - next_token_ids, logprobs = await self.async_sampling_logits(logits, all_ids, guided_input_ids, - sampling_inputs, inputs, num_ignore_eos > 0) - num_ignore_eos = num_ignore_eos - 1 + next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs) + + with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): + logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') + + # post sampling + next_token_ids, extra_inputs = self.agent_strategy.post_sampling( + inputs, last_logits, next_token_ids, extra_inputs) - # stopping criteria - stopped, num_appendable_ids = _batch_stopping_criteria(next_token_ids, sampling_inputs.stop_words, - num_appendable_ids) + # stopping criteria + stopped, stop_pos, stopping_criteria = stopping_criteria.step(next_token_ids, + sampling_inputs.stop_words, + inputs=inputs, + extra_inputs=extra_inputs) else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile. 
with torch.inference_mode(): - next_token_ids = torch.zeros_like(num_ignore_eos) + next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0)) logprobs = None - # broadcast next token for TP > 1 - need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1) - if need_broadcast_next: - logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - next_token_ids = self._broadcast_next_token(next_token_ids, dist_ctx) + # broadcast next token for TP > 1 + with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): + logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') + + # post sampling + next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids, + extra_inputs) # send output model_metas = output.get('model_metas') if need_output: logger.debug(f' rank[{rank}]: Output [{idx}]') + extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs) self._push_output( BatchedOutputs(next_token_ids=next_token_ids, logits=logits if return_logits else None, stopped=stopped, + stop_pos=stop_pos, model_metas=model_metas, - logprobs=logprobs)) + logprobs=logprobs, + extra_outputs=extra_outputs)) # update for next loop if is_decoding and idx < loop_count - 1: - __update_inputs(next_token_ids, model_metas) + inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs) async def _async_loop_background(self, forward_event: asyncio.Event = None): """Async loop background.""" @@ -787,7 +777,7 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None): async def _async_loop_inputs_preprocess(self): """Async loop inputs preprocess.""" non_blocking = True - keys = ['inputs', 'all_ids', 'guided_input_ids', 'sampling_inputs', 'num_appendable_ids', 'num_ignore_eos'] + keys = ['inputs', 'sampling_inputs', 'stopping_criteria', 'extra_inputs'] while True: forward_inputs = await self._pre_in_que.get() @@ -923,7 +913,9 @@ def _build_model(self): if custom_module_map is not None: update_custom_module_map(custom_module_map) logger.debug(msg_with_rank(rank, 'build model.')) - build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder) + build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder, + dllm_config=self.misc_config.dllm_config, + strategy_factory=self.strategy_factory) patched_model = build_patched_model(self.model_config, device=device, model_format=self.misc_config.model_format, @@ -935,6 +927,7 @@ def _build_model(self): logger.debug(msg_with_rank(rank, 'loading adapters.')) add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device) self.patched_model = patched_model + self.build_model_ctx = build_model_ctx def build_model(self): """Build model api.""" @@ -965,8 +958,7 @@ def build_cache_engine(self): world_size=tp, cache_stream=self.cache_stream) - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): - cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + def _forward_impl(self, inputs: ModelInputs): output = model_forward( self.patched_model, inputs, @@ -975,7 +967,7 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: ) return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): + async def async_forward(self, inputs: ModelInputs): """Model forward. 
Args: @@ -983,7 +975,7 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_ou swap_in_map (SwapMap): Cache maps to swap in. swap_out_map (SwapMap): Cache maps to swap out. """ - output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + output = self._forward_impl(inputs) await asyncio.sleep(0) return output @@ -1091,6 +1083,7 @@ def __init__(self, model_agent: BaseModelAgent): self.model_config = model_agent.model_config self.cache_config = model_agent.cache_config self.misc_config = model_agent.misc_config + self.inputs_strategy = model_agent.inputs_strategy self.device = model_agent.device self._in_que = model_agent._in_que @@ -1106,10 +1099,10 @@ def _make_dummy_forward_inputs(self): dist_config = self.dist_ctx.dist_config batch_size = 2 if dist_config.enable_microbatch else 1 batch_size = min(self.cache_config.max_batches, batch_size) - model_inputs = ModelInputs.make_dummy(batch_size, - is_decoding, - device=self.device, - vocab_size=self.model_config.vocab_size) + model_inputs = self.inputs_strategy.make_dummy(batch_size, + is_decoding, + device=self.device, + vocab_size=self.model_config.vocab_size) forward_inputs = dict( inputs=model_inputs, loop_count=loop_count, diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index a22f3d36c0..8b53b84aa2 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -56,7 +56,8 @@ def _load_kv(ptrs, boundary_check: tl.constexpr): def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask, kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr, - shared_kv: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr): + shared_kv: tl.constexpr, block_sparse_size: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_DK1: tl.constexpr): k_ptrs = tl.advance(k_ptrs, (0, loop_start)) v_ptrs = tl.advance(v_ptrs, (loop_start, 0)) if BLOCK_DK1: @@ -77,7 +78,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start qk *= sm_scale qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) - qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) + if block_sparse_size > 1: + offs_mask = (start_n + offs_n) // block_sparse_size * block_sparse_size + qk_mask = (history_mask[:, None]) >= offs_mask[None, :] + else: + qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) if window_size > 0: qk_mask = qk_mask & ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) qk = tl.where( @@ -180,6 +185,7 @@ def _flash_prefill_fwd_kernel( window_size: tl.constexpr, logit_softcapping: tl.constexpr, shared_kv: tl.constexpr, + block_sparse_size: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK: tl.constexpr, @@ -295,6 +301,7 @@ def _flash_prefill_fwd_kernel( k_bound=k_bound0, v_bound=v_bound0, shared_kv=shared_kv, + block_sparse_size=block_sparse_size, BLOCK_N=BLOCK_N, BLOCK_DK1=BLOCK_DK1) @@ -322,6 +329,7 @@ def _flash_prefill_fwd_kernel( k_bound=k_bound1, v_bound=v_bound1, shared_kv=shared_kv, + block_sparse_size=block_sparse_size, BLOCK_N=BLOCK_N, BLOCK_DK1=BLOCK_DK1) # epilogue @@ -448,6 +456,7 @@ def flash_attention_fwd( logit_softcapping: float = None, sinks: Tensor = None, causal: bool = True, + block_sparse_size: int = 1, kv_layout: str = 'hsd', ): """Varlen flash Attention forward. 
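Note on the `block_sparse_size` branch added to `_prefill_fwd_inner` above: it widens the causal mask from token granularity to block granularity, which is what a block-diffusion (dLLM) decoder such as SDAR needs; tokens inside one block attend to each other bidirectionally while attention across blocks stays causal. The SDAR model later wires `dllm_block_length` into this argument. Below is a minimal host-side NumPy sketch of the resulting mask, for illustration only (it is not the Triton kernel, and the example shapes are made up):

```python
import numpy as np


def block_sparse_causal_mask(history_mask: np.ndarray, kv_len: int, block_sparse_size: int) -> np.ndarray:
    """Host-side rendition of the qk_mask computed in _prefill_fwd_inner.

    history_mask[i] is the absolute kv position of query i (its causal
    horizon). With block_sparse_size > 1 the comparison is made against the
    block start of each kv position, so tokens inside one block see each
    other bidirectionally while attention across blocks stays causal.
    """
    kv_pos = np.arange(kv_len)
    if block_sparse_size > 1:
        kv_pos = kv_pos // block_sparse_size * block_sparse_size
    return history_mask[:, None] >= kv_pos[None, :]


# Four queries forming one dLLM block of size 4, appended after 4 cached tokens.
history_mask = np.arange(4, 8)
print(block_sparse_causal_mask(history_mask, kv_len=8, block_sparse_size=1).astype(int))
print(block_sparse_causal_mask(history_mask, kv_len=8, block_sparse_size=4).astype(int))
# With block_sparse_size=4 every query in the new block attends to all of
# positions 4..7 (they share block start 4), not just the positions before it.
```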
@@ -546,6 +555,7 @@ def grid(args): window_size=window_size, logit_softcapping=logit_softcapping, shared_kv=shared_kv, + block_sparse_size=block_sparse_size, BLOCK_DK=BLOCK_DK, BLOCK_DK1=BLOCK_DK1, BLOCK_DV=BLOCK_DV, diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index a6097e7027..a4941c36af 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -59,6 +59,7 @@ def _fwd_grouped_split_kernel( stride_od: tl.constexpr, stride_boffb, kv_group_num: tl.constexpr, + seq_len: tl.constexpr, window_size: tl.constexpr, head_size: tl.constexpr, head_size_v: tl.constexpr, @@ -74,18 +75,20 @@ def _fwd_grouped_split_kernel( ): """First step kernel of split k attention.""" cur_batch = tl.program_id(2) - cur_kv_head = tl.program_id(0) + tile_id = tl.program_id(0) split_k_id = tl.program_id(1) - if BLOCK_H < kv_group_num: - HEAD_PER_CTA: tl.constexpr = BLOCK_H - else: - HEAD_PER_CTA: tl.constexpr = kv_group_num - cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H) - mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA + HEADS_PER_REQ: tl.constexpr = kv_group_num * seq_len + TILES_PER_GROUP: tl.constexpr = tl.cdiv(HEADS_PER_REQ, BLOCK_H) + subtile_id = tile_id % TILES_PER_GROUP + cur_kv_head = tile_id // TILES_PER_GROUP + offs_h = subtile_id * BLOCK_H + tl.arange(0, BLOCK_H) + cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num + cur_token = cur_batch * seq_len + offs_h // kv_group_num + + mask_h = cur_head < cur_kv_head * kv_group_num + kv_group_num + mask_h = mask_h & (cur_token < cur_batch * seq_len + seq_len) mask_h = mask_h & (cur_head < num_heads_q) - if BLOCK_H < kv_group_num: - cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num q_seqlen = 1 kv_seqlen = tl.load(KV_seqlens + cur_batch) @@ -104,7 +107,7 @@ def _fwd_grouped_split_kernel( off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs) off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs) - off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) + off_q = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) q = tl.load(Q + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0) k_ptrs = K + off_k @@ -114,7 +117,7 @@ def _fwd_grouped_split_kernel( offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1) mask_d1 = offs_d1 < head_size offs_d1 = offs_d1 % head_size - off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) + off_q1 = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) q1 = tl.load(Q + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0) off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs) k1_ptrs = K + off_k1 @@ -196,11 +199,11 @@ def _fwd_grouped_split_kernel( # initialize pointers to output if loop_end > loop_start: - off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh + + off_acc = (cur_token[:, None] * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :]) - off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v) + off_meta = (cur_token * stride_obs + split_k_id * stride_ok + cur_head * 
stride_oh + head_size_v) tl.store(Acc_out + off_meta, m_i, mask=mask_h) tl.store(Acc_out + off_meta + 1, l_i, mask=mask_h) @@ -588,7 +591,9 @@ def _get_block_d(Lk): if sm_scale is None: sm_scale = 1.0 / (Lq**0.5) batch, head = kv_seqlens.shape[0], q.shape[-2] - kv_group_num = q.shape[-2] // k.shape[h_dim] + num_tokens = q.shape[-3] + num_kv_heads = k.shape[h_dim] + kv_group_num = head // num_kv_heads if sinks is not None: assert sinks.is_contiguous() @@ -601,20 +606,22 @@ def _get_block_d(Lk): 'might leads to bad performance. ' 'Please reduce `block_size`.') - is_decoding = q.shape[-3] == kv_seqlens.size(0) - assert is_decoding, 'we only support decoding paged attention.' + valid = num_tokens % batch == 0 + assert valid, 'we only support decoding paged attention.' + seq_len = num_tokens // batch BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq) - p2_kv_group_num = triton.next_power_of_2(kv_group_num) - BLOCK_H = max(16, min(BLOCK, p2_kv_group_num)) - grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num)) + HEADS_PER_REQ = kv_group_num * seq_len + BLOCK_H = max(16, min(BLOCK, triton.next_power_of_2(HEADS_PER_REQ))) + TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H) + grid_1 = TILES_PER_GROUP * num_kv_heads SPLIT_K = _get_split_k(q.device.index, grid_1, batch) if quant_policy != 4: - acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32) + acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32) else: - acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32) + acc = q.new_empty(num_tokens, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32) grid = ( grid_1, @@ -704,6 +711,7 @@ def _get_block_d(Lk): stride_od=acc.stride(-1), stride_boffb=block_offsets.stride(0), kv_group_num=kv_group_num, + seq_len=seq_len, window_size=window_size, head_size=Lk, head_size_v=Lv, @@ -720,7 +728,7 @@ def _get_block_d(Lk): num_stages=num_stages) num_warps = 4 - grid = (batch, head) + grid = (num_tokens, head) if quant_policy == 4: Lv *= 2 BLOCK_DV *= 2 diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 4cc559a20b..101ed62546 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -1,8 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
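Note on the pagedattention.py changes above: the split-k decode kernel now handles `seq_len` query tokens per request instead of exactly one, so a whole dLLM block can be decoded in a single pass. Each program along axis 0 covers a tile of `(head, token)` pairs for one KV head, and `cur_head` / `cur_token` are recovered from the flat tile offset. The sketch below mirrors that tiling arithmetic on the host for illustration; the shapes and the `BLOCK` default are made-up values, and the authoritative version is the Triton code above:

```python
import math


def decode_tiling(num_q_heads: int, num_kv_heads: int, seq_len: int, BLOCK: int = 64):
    """Mimic the tile layout used by _fwd_grouped_split_kernel (sketch only)."""
    kv_group_num = num_q_heads // num_kv_heads
    heads_per_req = kv_group_num * seq_len                    # (head, token) pairs per kv head
    block_h = max(16, min(BLOCK, 1 << max(0, heads_per_req - 1).bit_length()))
    tiles_per_group = math.ceil(heads_per_req / block_h)
    grid_1 = tiles_per_group * num_kv_heads                   # size of program_id(0)

    def unpack(tile_id: int, lane: int, cur_batch: int):
        """Recover which query head / token a lane of a tile serves."""
        subtile_id = tile_id % tiles_per_group
        cur_kv_head = tile_id // tiles_per_group
        offs_h = subtile_id * block_h + lane
        cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num
        cur_token = cur_batch * seq_len + offs_h // kv_group_num
        return cur_kv_head, cur_head, cur_token

    return grid_1, unpack


# 32 query heads grouped over 8 kv heads, decoding a dLLM block of 4 tokens.
grid_1, unpack = decode_tiling(num_q_heads=32, num_kv_heads=8, seq_len=4)
print(grid_1)                          # 8 tiles along program_id(0)
print(unpack(tile_id=1, lane=5, cur_batch=0))   # kv head 1 serves query head 5, token 1
```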
import enum -import time from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import numpy as np from torch import Tensor @@ -14,6 +13,9 @@ from .block import LogicalTokenBlocks +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy + logger = get_logger('lmdeploy') # vlm input type from pipeline @@ -56,7 +58,7 @@ class SamplingParam: num_logprobs: int = -1 @classmethod - def from_gen_config(self, gen_config: GenerationConfig): + def from_gen_config(cls, gen_config: GenerationConfig): """From gen config.""" min_new_tokens = gen_config.min_new_tokens or 0 @@ -158,29 +160,33 @@ class MessageStatus(enum.Enum): MIGRATION_DONE = enum.auto() -_SEQ_COUNT = 0 - - -def _new_msg_id(): - """Get a new message id.""" - global _SEQ_COUNT - seq_id = _SEQ_COUNT - _SEQ_COUNT += 1 - return seq_id +SeqMap = Dict[int, 'SchedulerSequence'] -SeqMap = Dict[int, 'SchedulerSequence'] +@dataclass +class SequenceMeta: + """Meta data shared by all sequence.""" + block_size: int + strategy: 'SequenceStrategy' = None class SequenceManager: """Sequence manager.""" - def __init__(self) -> None: + def __init__(self, seq_meta: SequenceMeta) -> None: self._seq_map: SeqMap = dict() self._status_seq_map: Dict[MessageStatus, SeqMap] = dict() for status in MessageStatus: self._status_seq_map[status] = dict() + self.seq_meta = seq_meta + self._seq_count = 0 + + def _new_seq_id(self): + seq_id = self._seq_count + self._seq_count += 1 + return seq_id + def get_all_sequences(self): """Get all sequences.""" return self._seq_map.values() @@ -223,12 +229,23 @@ def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageSt new_status_map[seq_id] = seq +def _to_ndarray(token_ids) -> np.ndarray: + """To ndarray.""" + if isinstance(token_ids, Tensor): + token_ids = token_ids.numpy() + elif not isinstance(token_ids, np.ndarray): + token_ids = np.array(token_ids) + if token_ids.ndim == 0: + token_ids = token_ids[None] + return token_ids + + class SchedulerSession: """Scheduler session.""" - def __init__(self, session_id: int, block_size: int, seq_manager: SequenceManager = None) -> None: + def __init__(self, session_id: int, seq_manager: SequenceManager) -> None: self.session_id = session_id - self.block_size = block_size + self.seq_meta = seq_manager.seq_meta self.status: MessageStatus = MessageStatus.RUNNING self.sequences: SeqMap = dict() self.seq_manager = seq_manager @@ -237,48 +254,38 @@ def add_sequence(self, token_ids: Tensor, sampling_param: SamplingParam = None, adapter_name: str = None, - return_logits: bool = False, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, migration_request: Optional[MigrationRequest] = None, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" - if isinstance(token_ids, Tensor): - token_ids = token_ids.numpy() - elif not isinstance(token_ids, np.ndarray): - token_ids = np.array(token_ids) - if token_ids.ndim == 0: - token_ids = token_ids.unsqueeze(0) if sampling_param is None: sampling_param = SamplingParam() - seq = SchedulerSequence( - seq_id=_new_msg_id(), - session=self, - history_cache=HistoryTokenIds(token_ids), - num_new_tokens=0, - sampling_param=sampling_param, - adapter_name=adapter_name, - arrive_time=time.perf_counter(), - history_embeddings=HistoryEmbeddings(input_embeddings), - history_multimodals=HistoryMultiModals(multimodals), - 
return_logits=return_logits, - migration_request=migration_request, - resp_cache=resp_cache, - preserve_cache=preserve_cache, + seq_id = self.seq_manager._new_seq_id() + seq = self.seq_meta.strategy.make_sequence(seq_id=seq_id, + session=self, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) + seq.update_token_ids( + token_ids, + multimodals=multimodals, + embeddings=input_embeddings, + mode=UpdateTokenMode.INPUTS, ) self.sequences[seq.seq_id] = seq - if self.seq_manager is not None: - self.seq_manager.add_sequence(seq) + self.seq_manager.add_sequence(seq) return seq def remove_sequence(self, seq: 'SchedulerSequence'): """Remove sequence.""" assert seq.seq_id in self.sequences self.sequences.pop(seq.seq_id) - if self.seq_manager is not None: - self.seq_manager.remove_sequence(seq) + self.seq_manager.remove_sequence(seq) def _div_up(x, n): @@ -342,9 +349,9 @@ class HistoryTokenIds: """History token ids.""" ALLOC_SIZE = 512 - def __init__(self, token_ids: np.ndarray = None): + def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = np.int64): if token_ids is None: - self._token_ids = np.empty((self.ALLOC_SIZE, ), dtype=np.int64) + self._token_ids = np.empty((self.ALLOC_SIZE, ), dtype=dtype) self._num_real = 0 else: self._token_ids = token_ids @@ -363,6 +370,11 @@ def get_real(self): """Get logical blocks.""" return self._token_ids[:self._num_real] + def resize(self, size: int): + """Set size.""" + assert size <= self._num_real + self._num_real = size + def __setitem__(self, *args, **kwargs): """Set values.""" return self.get_real().__setitem__(*args, **kwargs) @@ -397,7 +409,7 @@ def copy(self): class HistoryMultiModals: - def __init__(self, multimodals: MultiModalInputs): + def __init__(self, multimodals: MultiModalInputs = None): if multimodals is None: multimodals = dict() self.multimodals = multimodals @@ -453,6 +465,13 @@ def get_encoder_len(self, start=0, end=-1): return out_len +class UpdateTokenMode(enum.Enum): + """Update token mode.""" + INPUTS = enum.auto() + PREFILL = enum.auto() + DECODE = enum.auto() + + @dataclass class SchedulerSequence: """Scheduler message.""" @@ -466,9 +485,8 @@ class SchedulerSequence: logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks) adapter_name: str = None arrive_time: float = 0.0 + output_start_pos: int = 0 meta: Any = None - return_logits: bool = False - random_offsets: int = 0 _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 model_meta: Dict[str, Any] = None @@ -483,23 +501,20 @@ class SchedulerSequence: def __post_init__(self): """Post init.""" - self._num_history_ids: int = 0 + self._seq_meta: SequenceMeta = self.session.seq_meta self._num_history_images: int = 0 - self._num_images: int = len(self.history_embeddings) + self._num_history_ids: int = 0 self._num_token_ids: int = len(self.history_cache) + # vlm + self._num_images: int = len(self.history_embeddings) self._num_history_cross: int = 0 self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids) @property def block_size(self) -> int: """Block size.""" - return self.session.block_size - - @property - def history_len(self) -> int: - """Get history length.""" - return self._num_history_ids + return self._seq_meta.block_size @property def history_image_num(self) -> int: @@ -519,7 +534,7 @@ def session_id(self) -> int: @property def token_ids(self) -> np.ndarray: 
"""Token ids.""" - start = self.history_len + start = self.num_history_ids end = start + self._num_token_ids return self.history_cache._token_ids[start:end] @@ -533,13 +548,24 @@ def input_embeddings(self) -> List[InputEmbeddings]: @property def history_ids(self) -> np.ndarray: """History ids.""" - return self.history_cache._token_ids[:self.history_len] + return self.history_cache._token_ids[:self.num_history_ids] @property def all_ids(self) -> np.ndarray: """Full token ids.""" return self.history_cache._token_ids[:self.num_all_ids] + @property + def valid_ids(self) -> np.ndarray: + """Valid token ids.""" + return self.history_cache._token_ids[:self.num_valid_ids] + + @property + def generated_ids(self) -> np.ndarray: + end = self.num_valid_ids + start = end - self.num_new_tokens + return self.history_cache._token_ids[start:end] + @property def num_history_ids(self): """Num history ids.""" @@ -549,6 +575,10 @@ def num_history_ids(self): def num_token_ids(self): return self._num_token_ids + @property + def num_valid_ids(self): + return self._num_history_ids + self._num_token_ids + @property def num_images(self): return self._num_images @@ -556,7 +586,7 @@ def num_images(self): @property def num_all_ids(self): """Num all tokens.""" - return self.history_len + self._num_token_ids + return self._num_history_ids + self._num_token_ids @property def num_cross(self): @@ -582,15 +612,15 @@ def seq_manager(self) -> SequenceManager: def status(self): return self._status + @property + def return_logits(self): + return self.sampling_param.out_logits + @status.setter def status(self, value: MessageStatus): self.seq_manager.update_sequence_status(self, value) self._status = value - def num_all_tokens(self): - """Num all tokens.""" - return self.num_all_ids - def num_all_cross_tokens(self): """Num of all cross tokens.""" return self._num_cross + self._num_history_cross @@ -601,79 +631,45 @@ def get_input_multimodals(self): end = self.num_all_ids return self.history_multimodals.get_datas(start, end) - def update_token_ids(self, - token_ids: Tensor, - multimodals: MultiModalInputs = None, - embeddings: List[InputEmbeddings] = None, - model_meta: Dict[str, Any] = None, - append_tokens: bool = False): - """Update token ids, old token ids will be added to history.""" - old_num_history_ids = self._num_history_ids - - # update history - if not append_tokens: - self._num_history_ids += self._num_token_ids + def record_event( + self, + event_type: EventType, + timestamp: Optional[float] = None, + ) -> None: + self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) - # update history image nums + def _update_embeddings(self, embeddings: List[InputEmbeddings]): + """Update input embeddings.""" self._num_history_images += self._num_images - self._num_images = 0 - if embeddings is not None: - new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings] - self._num_images = len(new_embeddings) - self.history_embeddings.append(new_embeddings) - - # update multimodals - if multimodals is not None: - multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids) - self.history_multimodals.add_inputs(multimodals) + if embeddings is None: + self._num_images = 0 + return + new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings] + self._num_images = len(new_embeddings) + self.history_embeddings.append(new_embeddings) - # cross + def _update_multimodals(self, multimodals: MultiModalInputs): + """Update input multimodals.""" 
self._num_history_cross += self._num_cross - if multimodals is not None: - self._num_cross = self.history_multimodals.get_encoder_len(old_num_history_ids, self._num_history_ids) - else: + if multimodals is None: self._num_cross = 0 + return + multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids) + self.history_multimodals.add_inputs(multimodals) - if model_meta is not None: - self.model_meta = model_meta - - if isinstance(token_ids, Tensor): - token_ids = token_ids.numpy() - elif not isinstance(token_ids, np.ndarray): - token_ids = np.array(token_ids) - if token_ids.ndim == 0: - token_ids = token_ids[None] - if append_tokens: - self._num_token_ids += len(token_ids) - else: - self._num_token_ids = len(token_ids) - self.history_cache.append(token_ids) - self.random_offsets += 1 - self.arrive_time = time.perf_counter() + # for mllama + self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, self._num_history_ids) + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + raise NotImplementedError('NotImplemented') def set_step(self, step: int): """Set step.""" - num_all_ids = self.num_all_ids - # update step for vlm - if len(self.history_embeddings) > 0: - new_step, self._num_history_images, self._num_images = \ - self.history_embeddings.get_step(step) - assert 0 <= new_step <= step - step = new_step - self._num_history_ids = step - self._num_token_ids = num_all_ids - step - self.num_ignored_history = min(step, self.num_ignored_history) - - self.model_meta = None - - # cross - if self.history_multimodals is not None: - self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids) - self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids) - - def record_event( - self, - event_type: EventType, - timestamp: Optional[float] = None, - ) -> None: - self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) + raise NotImplementedError('NotImplemented') diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 2afa8a0b9f..a377c9d4d6 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
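Note on the messages.py changes above: sequence construction is now delegated to a shared `SequenceMeta`, which carries the block size plus the paradigm-specific `SequenceStrategy`, and the old module-level id counter moves into `SequenceManager`. A minimal wiring sketch follows; it assumes this branch of lmdeploy is importable, the block size is a made-up value, and the strategy is left unset here because it is supplied by the AR / dLLM factory inside the engine:

```python
from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta

seq_meta = SequenceMeta(block_size=64)      # block_size comes from CacheConfig in the engine
seq_manager = SequenceManager(seq_meta)     # also owns the per-manager sequence-id counter
session = SchedulerSession(session_id=0, seq_manager=seq_manager)
# session.add_sequence(token_ids) would then call
# seq_meta.strategy.make_sequence(...) to build the paradigm-specific
# SchedulerSequence and seq.update_token_ids(token_ids, mode=UpdateTokenMode.INPUTS).
```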
from contextlib import contextmanager from dataclasses import dataclass, field, fields -from typing import Any, Dict, List, Literal +from typing import TYPE_CHECKING, Any, Dict, List, Literal import torch from torch.profiler import record_function @@ -9,9 +9,12 @@ # from torch import distributed as dist import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.backends import get_backend -from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.config import DLLMConfig, ModelConfig from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base import StrategyFactoryBase + @dataclass class DPMeta: @@ -142,12 +145,14 @@ class ModelInputs: dp_meta: 'DPMeta' = None enable_microbatch: bool = False - def update(self, input_ids: torch.LongTensor): + def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None): """Update input ids.""" assert self.is_decoding - self.history_lengths = self.history_lengths + 1 - self.max_kv_seqlen += 1 - self.sum_kv_seqlen += self.seq_length.numel() + if step_seqlens is None: + step_seqlens = self.seq_length + self.history_lengths += step_seqlens + self.max_kv_seqlen += self.max_q_seqlen + self.sum_kv_seqlen += self.max_kv_seqlen * self.seq_length.numel() if input_ids.dim() == 1: input_ids = input_ids[None, :] self.input_ids = input_ids @@ -279,38 +284,6 @@ def build_dp_meta(self): """Build dp meta.""" self.dp_meta = DPMeta.build(self.input_ids.numel()) - @classmethod - @record_function('make_dummy_input') - def make_dummy(cls, - batch_size: int, - is_decoding: bool, - device: str = 'cpu', - dummy_block_id: int = 0, - vocab_size: int = 1): - """Make dummy inputs.""" - input_ids = torch.randint(0, vocab_size, ( - 1, - batch_size, - ), dtype=torch.long, device=device) - seq_length = torch.ones((batch_size, ), dtype=torch.long, device=device) - history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device) - block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) - num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device) - local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device) - - return cls( - input_ids=input_ids, - seq_length=seq_length, - history_lengths=history_lengths, - block_offsets=block_offsets, - is_decoding=is_decoding, - num_ignored_history=num_ignored_history, - max_q_seqlen=1, - max_kv_seqlen=1, - sum_kv_seqlen=batch_size, - local_adapter_ids=local_adapter_ids, - ) - def log_info(self): """Get log info.""" ret = (f'num_tokens={self.input_ids.numel()}, batch_size={self.seq_length.numel()}' @@ -426,18 +399,27 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): """Get position ids.""" q_seqlens = inputs.seq_length history_seqlens = inputs.history_lengths + max_q_seqlen = inputs.max_q_seqlen # decoding - if inputs.is_decoding: + if max_q_seqlen == 1: attention_mask = torch.ones_like(q_seqlens)[:, None] position_ids = history_seqlens.unsqueeze(-1).clone() position_ids = position_ids.flatten() return attention_mask, position_ids num_tokens = inputs.input_ids.numel() - max_q_seqlen = inputs.max_q_seqlen + batch_size = inputs.seq_length.numel() device = q_seqlens.device + # batch with same seqlens + if max_q_seqlen * batch_size == num_tokens: + attention_mask = None + ranges = torch.arange(0, max_q_seqlen, device=device) + position_ids = history_seqlens[:, None] + ranges[None, :] + position_ids = position_ids.flatten() + return attention_mask, 
position_ids + # get mask mask_range = torch.arange(max_q_seqlen, device=device)[None, :] attention_mask = (mask_range < q_seqlens[:, None]).long() @@ -451,14 +433,24 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): return attention_mask, position_ids_1d +@dataclass +class BuildModelContext: + """Context for building model.""" + disable_vision_encoder: bool = False + dllm_config: DLLMConfig = None + strategy_factory: 'StrategyFactoryBase' = None + + class StepContextManager: - def __init__(self): + def __init__(self, build_ctx: BuildModelContext = None): self._current_ctx = None + build_ctx = build_ctx or BuildModelContext() + self.build_ctx = build_ctx - @staticmethod @record_function('build_step_context') def build_context( + self, inputs: ModelInputs, model_config: ModelConfig, kv_caches: List = None, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 0578a0ac4c..498e2c6554 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -225,4 +225,10 @@ 'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM', }) +# SDAR +MODULE_MAP.update({ + 'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM', + 'SDARMoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar_moe.SDARMoeForCausalLM', +}) + CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 97ad271d9a..86aa5ea51f 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -9,9 +9,9 @@ from typing import Any, Dict import torch -from attr import dataclass from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.model_inputs import BuildModelContext, StepContextManager from lmdeploy.utils import get_logger from ..config import ModelConfig @@ -188,10 +188,11 @@ def _get_model_class(config, module_map): def build_model_from_hf_config(model_config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None, + ctx_mgr: StepContextManager = None, build_model_ctx: 'BuildModelContext' = None): """Build model from hf config.""" - from lmdeploy.pytorch.model_inputs import StepContextManager - ctx_mgr = StepContextManager() + if ctx_mgr is None: + ctx_mgr = StepContextManager(build_model_ctx) module_map = _get_module_map() if device is None: device = torch.device('cuda') @@ -333,12 +334,6 @@ def add_adapters(model: torch.nn.Module, return target_infos -@dataclass -class BuildModelContext: - """Context for building model.""" - disable_vision_encoder: bool = False - - BUILD_MODEL_CTX = BuildModelContext() diff --git a/lmdeploy/pytorch/models/sdar.py b/lmdeploy/pytorch/models/sdar.py new file mode 100644 index 0000000000..6a624e40e4 --- /dev/null +++ b/lmdeploy/pytorch/models/sdar.py @@ -0,0 +1,405 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config +from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class SDARAttention(nn.Module): + """attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + # packed qkv + # Qwen3 uses 'config.attention_bias = False' for q/k/o projections + self.qkv_proj = build_qkv_proj(hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + num_replicate_kv_heads=num_replicate_kv_heads) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + dllm_block_length = config.dllm_block_length + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + block_sparse_size=dllm_block_length, + ) + + # o_proj + self.o_proj = build_o_proj(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + # q, k norm + self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states) + + # apply q, k norm + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + ) + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class SDARMLP(nn.Module): + """mlp.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + 
quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_gateup_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_down_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class SDARDecoderLayer(nn.Module): + """Decode layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = SDARAttention(config, dtype=dtype, device=device) + + # build MLP + self.mlp = SDARMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class SDARModel(nn.Module): + """model.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + SDARDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # 
rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class SDARForCausalLM(nn.Module, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length + # build model + self.model = SDARModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def update_weights(self): + """Update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for 
name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/sdar_moe.py b/lmdeploy/pytorch/models/sdar_moe.py new file mode 100644 index 0000000000..522d2aed95 --- /dev/null +++ b/lmdeploy/pytorch/models/sdar_moe.py @@ -0,0 +1,501 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config +from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class SDARMoeAttention(nn.Module): + """attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + # packed qkv + # Qwen3 uses 'config.attention_bias = False' for q/k/o projections + self.qkv_proj = build_qkv_proj(hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + num_replicate_kv_heads=num_replicate_kv_heads) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + dllm_block_length = config.dllm_block_length + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + block_sparse_size=dllm_block_length, + ) + + # o_proj + self.o_proj = build_o_proj(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + # q, k norm + self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, 
key_states, value_states = self.qkv_proj.split_qkv(qkv_states) + + # apply q, k norm + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + ) + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class SDARMoeMLP(nn.Module): + """mlp.""" + + def __init__(self, + config: PretrainedConfig, + intermediate_size: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_gateup_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_down_linear(intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class SDARMoeSparseMoeBlock(nn.Module): + """SDARMoeSparseMoeBlock.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.renormalize = config.norm_topk_prob + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(self.top_k) + + world_size, _ = get_tp_world_rank() + _all_reduce = world_size > 1 + self.experts = build_fused_moe( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=self.renormalize, + dtype=dtype, + device=device, + quant_config=quantization_config, + all_reduce=_all_reduce, + layer_idx=layer_idx, + ) + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + topk_weights, topk_ids = self.softmax_topk(router_logits) + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + out_states = out_states.reshape(batch_size, sequence_length, -1) + return out_states + + +class SDARMoeDecoderLayer(nn.Module): + """Decode layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = SDARMoeAttention(config, 
dtype=dtype, device=device) + + # build MLP + if (layer_idx not in config.mlp_only_layers) and (config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = SDARMoeSparseMoeBlock(config, layer_idx, dtype=dtype, device=device) + else: + self.mlp = SDARMoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class SDARMoeModel(nn.Module): + """SDAR model.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + SDARMoeDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class SDARMoeForCausalLM(nn.Module, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 
'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length + # build model + self.model = SDARMoeModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def update_weights(self): + """Update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """Load weight experts.""" + # load fused weights + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert map + num_experts = self.config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 
'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + + if '.experts' in name: + self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping) + else: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index b86adbbb89..9901bf443c 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -142,15 +142,15 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p new_inputs['cross_attn_metadata'] = cross_attn_metadata if is_decoding: - new_inputs['input_ids'] = input_buffers['input_ids'][:, :new_batch_size] - new_inputs['position_ids'] = input_buffers['position_ids'][:, :new_batch_size] + new_inputs['input_ids'] = input_buffers['input_ids'][:, :num_tokens] + new_inputs['position_ids'] = input_buffers['position_ids'][:, :num_tokens] else: new_inputs['input_ids'] = input_buffers['input_ids'] new_inputs['position_ids'] = input_buffers['position_ids'] if inputs_embeds is not None: if is_decoding: - new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'][:, :new_batch_size] + new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'][:, :num_tokens] else: new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index ef5d831d51..7a1654db4b 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -37,6 +37,7 @@ def __init__( causal: bool = True, use_flash_mla: bool = False, learnable_sink: bool = False, + block_sparse_size: int = 1, **kwargs, ): super().__init__() @@ -61,6 +62,7 @@ def __init__( causal=causal, use_flash_mla=use_flash_mla, learnable_sink=learnable_sink, + block_sparse_size=block_sparse_size, **kwargs, ) diff --git a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py index cc41187900..3348f2f8a5 100644 --- a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py @@ -25,7 +25,7 @@ class DefaultBlockManager(BaseBlockManager): @classmethod def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0): """Get num required blocks.""" - num_tokens = obj.num_all_tokens() + prealloc_size + num_tokens = obj.num_all_ids + prealloc_size # cross tokens num_cross = obj.num_all_cross_tokens() diff --git a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py index 3f5ca9605e..1d91c020cb 100644 --- a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py @@ -10,10 +10,10 @@ def _num_blocks_to_drop(seq: 
SchedulerSequence, window_size: int): """Num blocks to free.""" - if seq.history_len <= window_size: + history_len = seq.num_history_ids + if seq.num_history_ids <= window_size: return 0 block_size = seq.block_size - history_len = seq.history_len num_blocks = len(seq.logical_blocks) win_start_block_id = (history_len - window_size) // block_size win_end_block_id = (history_len - 1) // block_size diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py index b5908dde54..d00690af24 100644 --- a/lmdeploy/pytorch/paging/block_trie.py +++ b/lmdeploy/pytorch/paging/block_trie.py @@ -81,7 +81,7 @@ def __match_success(node: Node): curr = node num_matched += block_size - while num_matched + block_size < seq.num_all_ids: + while num_matched + block_size < seq.num_valid_ids: curr_tokens = seq.history_cache[num_matched:num_matched + block_size] key = hash(('random', tuple(curr_tokens))) @@ -116,9 +116,9 @@ def allocate(self, seq: SchedulerSequence): logical_blocks.last_shared_node = node num_matched = node.num_matched - num_all_ids = seq.num_all_ids + num_valid_ids = seq.num_valid_ids - if num_matched + block_size > num_all_ids: + if num_matched + block_size > num_valid_ids: return if len(node.children) == 0 and node.parent is not None: @@ -127,7 +127,7 @@ def allocate(self, seq: SchedulerSequence): block_id = num_matched // block_size blocks = [] free_blocks = [] - while num_matched + block_size <= num_all_ids: + while num_matched + block_size <= num_valid_ids: curr_tokens = seq.history_cache[num_matched:num_matched + block_size] block = logical_blocks[block_id] diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 0be35ab9be..8144ac52a1 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -9,7 +9,7 @@ from lmdeploy.utils import get_logger, logging_timer from ..config import CacheConfig, SchedulerConfig -from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager +from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta from .block_manager import build_block_manager from .block_trie import BlockTrie @@ -36,7 +36,10 @@ class Scheduler: cache_config (CacheConfig): The config of cache info. """ - def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None: + def __init__(self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + seq_meta: SequenceMeta = None) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -50,7 +53,8 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type) - self.seq_manager = SequenceManager() + seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) + self.seq_manager = SequenceManager(seq_meta) @property def waiting(self): @@ -121,7 +125,7 @@ def add_session(self, session_id: int): session_id (int): New session id. 
""" assert session_id not in self.sessions - session = SchedulerSession(session_id, self.cache_config.block_size, seq_manager=self.seq_manager) + session = SchedulerSession(session_id, seq_manager=self.seq_manager) self.sessions[session_id] = session return session @@ -179,7 +183,7 @@ def _reorder_migrating(): return running_migration @logging_timer('SchedulePrefilling', logger) - def _schedule_prefill(self): + def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() @@ -203,7 +207,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): hanging = reversed(self.hanging) waiting = reversed(waiting) evictable = list(chain(hanging, waiting)) - return eviction_helper.evict_for_seq(seq, evictable, 0) + return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) def _reorder_waiting(): """Reorder waiting.""" @@ -226,7 +230,7 @@ def _reorder_waiting(): break # allocate session memory - self.block_manager.allocate(seq) + self.block_manager.allocate(seq, prealloc_size) _to_running(seq) seq.record_event(EventType.SCHEDULED) @@ -287,7 +291,7 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" if is_prefill: - output = self._schedule_prefill() + output = self._schedule_prefill(0) else: output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output diff --git a/lmdeploy/pytorch/strategies/__init__.py b/lmdeploy/pytorch/strategies/__init__.py new file mode 100644 index 0000000000..c0f5da1262 --- /dev/null +++ b/lmdeploy/pytorch/strategies/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.config import MiscConfig, ModelConfig + + +def build_strategy_factory(model_config: ModelConfig, misc_config: MiscConfig): + """Build strategy factory.""" + model_paradigm = model_config.model_paradigm + + if model_paradigm == 'ar': + from .ar import ARStrategyFactory + return ARStrategyFactory(model_config=model_config) + elif model_paradigm == 'dllm': + from .dllm import DLLMStrategyFactory + return DLLMStrategyFactory(model_config=model_config, dllm_config=misc_config.dllm_config) + else: + raise RuntimeError(f'Unsupported model paradigm: {model_paradigm}') diff --git a/lmdeploy/pytorch/strategies/ar/__init__.py b/lmdeploy/pytorch/strategies/ar/__init__.py new file mode 100644 index 0000000000..b593107c2e --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import TYPE_CHECKING + +from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy + +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy + from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy + from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy + from lmdeploy.pytorch.strategies.base.engine import EngineStrategy + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base import StrategyFactoryBase + + +class ARStrategyFactory(StrategyFactoryBase): + + def __init__(self, model_config: ModelConfig): + """config.""" + self.model_config = model_config + + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + from .cudagraph import ARCudagraphStrategy + return ARCudagraphStrategy() + + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + from .sampling import ARSamplingStrategy + pad_token_id = self.model_config.bos_token_id + pad_token_id = 0 if pad_token_id is None else pad_token_id + return ARSamplingStrategy(pad_token_id) + + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + from .model_inputs import ARModelInputsStrategy + return ARModelInputsStrategy() + + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + from .model_agent import ARModelAgentStrategy + return ARModelAgentStrategy() + + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + from .engine import AREngineStrategy + return AREngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config) + + def build_sequence_strategy(self) -> SequenceStrategy: + from .sequence import ARSequenceStrategy + return ARSequenceStrategy() diff --git a/lmdeploy/pytorch/strategies/ar/cudagraph.py b/lmdeploy/pytorch/strategies/ar/cudagraph.py new file mode 100644 index 0000000000..e3749bcfc2 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/cudagraph.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..base.cudagraph import CudagraphStrategy + + +class ARCudagraphStrategy(CudagraphStrategy): + + def get_max_tokens(self, batch_size: int) -> int: + """Get max tokens.""" + return batch_size diff --git a/lmdeploy/pytorch/strategies/ar/engine.py b/lmdeploy/pytorch/strategies/ar/engine.py new file mode 100644 index 0000000000..60ae69f8c9 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/engine.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base.engine import EngineStrategy + + +class AREngineStrategy(EngineStrategy): + """AR Engine Strategy.""" + + def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + + def get_prealloc_size(self, is_decoding: bool): + """Get prealloc_size.""" + return self.scheduler_config.prefill_interval if is_decoding else 0 + + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + return self.scheduler_config.prefill_interval if is_decoding else 1 diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py new file mode 100644 index 0000000000..4096db2cb7 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Any, List, Optional + +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria + +SeqList = List[SchedulerSequence] + + +class ARExtraInputs(ExtraInputs): + """Ar extra inputs.""" + + +class ARExtraOutputs(ExtraOutputs): + """Ar extra outputs.""" + + +@dataclass +class ARStoppingCriteria(StoppingCriteria): + num_appendable_ids: torch.Tensor + + @record_function('stopping_criteria') + def step(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional[ModelInputs] = None, + extra_inputs: Optional[ARExtraInputs] = None): + """Check whether to stop generation.""" + num_appendable_ids = self.num_appendable_ids - 1 + stopped = num_appendable_ids <= 0 + stop_pos = torch.zeros_like(num_appendable_ids) + if stop_words is not None: + sw_stopped = (token_ids[:, None] == stop_words).any(1) + stopped = stopped | sw_stopped + one_ids = torch.clamp_max(num_appendable_ids, 0) + num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + + # In-place assignment does not work here for unclear reasons; return a new criteria instance instead.
+ new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids) + return stopped, stop_pos, new_stopping + + +class ARModelAgentStrategy(ModelAgentStrategy): + + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + # batch size == 1 + if len(seq_length) == 1: + return inputs[-1:] + + if len(seq_length) == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(-1) - 1 + return inputs[last_idx] + + def slice_extra_inputs(self, extra_inputs: ARExtraInputs, seq_length: torch.LongTensor) -> ARExtraInputs: + """Slice outputs.""" + return extra_inputs + + def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor): + """step.""" + sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1 + + all_ids = sampling_inputs.all_ids + if all_ids is not None: + sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1) + + guided_input_ids = sampling_inputs.guided_input_ids + if guided_input_ids is not None: + sampling_inputs.guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None]], 1) + + return sampling_inputs + + def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria: + """Create stopping criteria.""" + num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] + num_appendable = torch.tensor(num_appendable) + return ARStoppingCriteria(num_appendable_ids=num_appendable) + + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + return ARExtraInputs() + + def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs: + """Create extra outputs.""" + return ARExtraOutputs() + + def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ARExtraInputs, + **kwargs): + """Step next inputs.""" + model_inputs.model_metas = model_metas + step_seqlens = model_inputs.seq_length + model_inputs.step(next_token_ids, step_seqlens) + self._step_sampling_inputs(sampling_inputs, next_token_ids) + return model_inputs, extra_inputs + + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: ARExtraInputs): + """Post sampling.""" + return next_token_ids, extra_inputs diff --git a/lmdeploy/pytorch/strategies/ar/model_inputs.py b/lmdeploy/pytorch/strategies/ar/model_inputs.py new file mode 100644 index 0000000000..5b263b1e37 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/model_inputs.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs + + +class ARModelInputsStrategy(ModelInputsStrategy): + + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1) -> ModelInputs: + """Create dummy model inputs.""" + return make_dummy_inputs(batch_size, + max_q_seqlen=1, + is_decoding=is_decoding, + device=device, + dummy_block_id=dummy_block_id, + vocab_size=vocab_size) diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py new file mode 100644 index 0000000000..b2516f091a --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
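`ARModelAgentStrategy.slice_outputs` above keeps only the last position of each packed sequence by indexing with `seq_length.cumsum(-1) - 1`. A tiny standalone illustration with made-up shapes:

```python
# Illustration only (made-up values): gather the last hidden state of each
# sequence from a packed [num_tokens, ...] tensor, as slice_outputs does.
import torch

outputs = torch.arange(9).unsqueeze(-1)   # 9 packed tokens, feature dim 1
seq_length = torch.tensor([4, 2, 3])      # three sequences: 4 + 2 + 3 tokens
last_idx = seq_length.cumsum(-1) - 1      # tensor([3, 5, 8])
print(outputs[last_idx].flatten())        # tensor([3, 5, 8])
```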
+from typing import List + +import torch + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence + +from ..base.sampling import SamplingStrategy + +SeqList = List[SchedulerSequence] + + +def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs): + """Gather history.""" + if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): + return None + batch = len(seqs) + max_len = max(seq.num_valid_ids for seq in seqs) + output = torch.full((batch, max_len), pad_id, dtype=torch.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_valid_ids + if h_len == 0: + continue + h_ids = torch.from_numpy(seq.valid_ids) + output[idx, -h_len:] = h_ids + return output + + +def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): + """Gather input ids for guided decode.""" + if not any(sampling_inputs.response_formats or ()): + return None + batch = len(seqs) + max_len = max(seq.num_new_tokens for seq in seqs) + output = torch.full((batch, max_len), pad_id, dtype=torch.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_new_tokens + if h_len == 0: + continue + h_ids = torch.from_numpy(seq.generated_ids) + output[idx, -h_len:] = h_ids + return output + + +def _get_num_ignore_eos(seqs: SeqList): + """Get num ignore eos.""" + ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] + return torch.tensor(ret) + + +class ARSamplingStrategy(SamplingStrategy): + """Sampling strategy for autoregressive models.""" + + def __init__(self, pad_token_id: int) -> None: + pad_token_id = 0 if pad_token_id is None else pad_token_id + self.pad_token_id = pad_token_id + + def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: + """Create sampling inputs from the sequences.""" + batch_size = len(seqs) + temperature = [None] * batch_size + repetition_penalty = [None] * batch_size + top_k = [None] * batch_size + top_p = [None] * batch_size + min_p = [None] * batch_size + bad_words = [None] * batch_size + stop_words = [None] * batch_size + random_seeds = [torch.seed() & 0xffffffff] * batch_size + random_offsets = [None] * batch_size + response_formats = [None] * batch_size + logits_processors = [None] * batch_size + num_logprobs = [None] * batch_size + + def __gather_params(): + """Gather params.""" + for idx, seq in enumerate(seqs): + param = seq.sampling_param + temperature[idx] = param.temperature + repetition_penalty[idx] = param.repetition_penalty + top_k[idx] = param.top_k + top_p[idx] = param.top_p + min_p[idx] = param.min_p + random_offsets[idx] = seq.num_valid_ids + response_formats[idx] = param.response_format + if param.random_seed is not None: + random_seeds[idx] = param.random_seed & 0xffffffff + + bw = param.bad_words + sw = param.stop_words + if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens): + bw = bw + sw + bad_words[idx] = bw + stop_words[idx] = sw + logits_processors[idx] = param.logits_processors + num_logprobs[idx] = param.num_logprobs + + def __get_topp(top_p): + """Get topp.""" + min_top_p = min(top_p) + if min_top_p == 1.0: + top_p = None + else: + top_p = torch.tensor(top_p) + return top_p, min_top_p + + def __get_minp(min_p): + """Get minp.""" + max_min_p = max(min_p) + if max_min_p == 0.0: + min_p = None + else: + min_p = torch.Tensor(min_p) + return min_p + + def __get_bad_words(bad_words): + """Get bad words.""" + max_bw_len = max(len(bw) for bw in bad_words) + if max_bw_len 
== 0: + return None, None + if all(len(bw) == max_bw_len for bw in bad_words): + ret = torch.tensor(bad_words) + mask = torch.ones_like(ret, dtype=bool) + return ret, mask + ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) + for idx, bw in enumerate(bad_words): + bw_len = len(bw) + if bw_len == 0: + continue + bw = ret.new_tensor(bw) + ret[idx, :bw_len] = bw + + mask = ret >= 0 + ret = ret.where(mask, 0) + return ret, mask + + __gather_params() + + if all(rp == 1.0 for rp in repetition_penalty): + repetition_penalty = None + else: + repetition_penalty = torch.tensor(repetition_penalty) + + temperature = torch.tensor(temperature) + + bad_words, bad_mask = __get_bad_words(bad_words) + stop_words, stop_mask = __get_bad_words(stop_words) + + max_top_k = max(top_k) + if min(top_k) <= 0: + max_top_k = 0 + if max_top_k == 1: + top_k = None + top_p, min_top_p = None, 1.0 + min_p = None + random_seeds = None + random_offsets = None + else: + top_k = torch.tensor(top_k) + top_p, min_top_p = __get_topp(top_p) + min_p = __get_minp(min_p) + random_seeds = torch.tensor(random_seeds) + random_offsets = torch.tensor(random_offsets) + + max_num_logprobs = max(num_logprobs) + + sampling_input = SamplingInputs( + temperature=temperature, + bad_words=bad_words, + bad_mask=bad_mask, + stop_words=stop_words, + stop_mask=stop_mask, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + min_p=min_p, + random_seeds=random_seeds, + random_offsets=random_offsets, + response_formats=tuple(response_formats), + max_top_k=max_top_k, + min_top_p=min_top_p, + logits_processors=logits_processors, + max_num_logprobs=max_num_logprobs, + batch_size=batch_size, + ) + + pad_token_id = self.pad_token_id + sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) + sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input) + sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) + return sampling_input diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py new file mode 100644 index 0000000000..91a3335f18 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -0,0 +1,113 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
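`_gather_all_ids` above right-aligns every sequence's token history into one `(batch, max_len)` tensor, left-padding shorter histories with the pad token so repetition penalty and logits processors see aligned recent tokens. A quick standalone check with made-up ids:

```python
# Standalone check of the left-padding pattern used by _gather_all_ids
# (token ids below are made up).
import torch

pad_id = 0
histories = [torch.tensor([11, 12, 13]), torch.tensor([21])]
max_len = max(len(h) for h in histories)
output = torch.full((len(histories), max_len), pad_id, dtype=torch.int64)
for idx, h in enumerate(histories):
    output[idx, -len(h):] = h
print(output)  # tensor([[11, 12, 13], [ 0,  0, 21]])
```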
+import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from torch import Tensor + +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.engine.model_agent import BatchedOutputs +from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, + SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray) + +from ..base.sequence import SequenceStrategy + +SeqList = List[SchedulerSequence] + + +@dataclass +class SchedulerSequenceDefault(SchedulerSequence): + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(multimodals) + + token_ids = _to_ndarray(token_ids) + + num_valid = len(token_ids) + + if mode == UpdateTokenMode.INPUTS: + self.arrive_time = time.perf_counter() + self.output_start_pos = self.num_all_ids + len(token_ids) + self._num_token_ids += num_valid + self.num_new_tokens = 0 + else: + self._num_history_ids += self._num_token_ids + num_token_ids = num_valid + self._num_token_ids = num_token_ids + self.num_new_tokens += num_token_ids + + self.history_cache.append(token_ids) + + if model_meta is not None: + self.model_meta = model_meta + + def set_step(self, step: int): + """Set step.""" + num_all_ids = self.num_all_ids + # update step for vlm + if len(self.history_embeddings) > 0: + new_step, self._num_history_images, self._num_images = \ + self.history_embeddings.get_step(step) + assert 0 <= new_step <= step + step = new_step + self._num_history_ids = step + self._num_token_ids = num_all_ids - step + self.num_ignored_history = min(step, self.num_ignored_history) + + self.model_meta = None + + # cross + if self.history_multimodals is not None: + self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids) + self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids) + + +class ARSequenceStrategy(SequenceStrategy): + + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequence': + """Make sequence.""" + return SchedulerSequenceDefault(seq_id=seq_id, + session=session, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) + + def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None: + """Update running sequences.""" + next_token_ids = batched_outputs.next_token_ids + stopped = batched_outputs.stopped + stopped = stopped.tolist() + model_metas = batched_outputs.model_metas + if model_metas is None: + model_metas = [None] * len(running) + + next_token_ids = next_token_ids.numpy() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL + for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): + if msg.status != MessageStatus.LOCKED: + continue + + # fill token + msg.update_token_ids(token, model_meta=model_meta, 
mode=update_mode) + if stop: + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED diff --git a/lmdeploy/pytorch/strategies/base/__init__.py b/lmdeploy/pytorch/strategies/base/__init__.py new file mode 100644 index 0000000000..42519779ad --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/__init__.py @@ -0,0 +1,47 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + + from .cudagraph import CudagraphStrategy + from .engine import EngineStrategy + from .model_agent import ModelAgentStrategy + from .model_inputs import ModelInputsStrategy + from .sampling import SamplingStrategy + from .sequence import SequenceStrategy + + +class StrategyFactoryBase(ABC): + + @abstractmethod + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + pass + + @abstractmethod + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + pass + + @abstractmethod + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + pass + + @abstractmethod + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + pass + + @abstractmethod + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + pass + + @abstractmethod + def build_sequence_strategy(self) -> 'SequenceStrategy': + """Build sequence strategy.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/cudagraph.py b/lmdeploy/pytorch/strategies/base/cudagraph.py new file mode 100644 index 0000000000..795c3b5350 --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/cudagraph.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + + +class CudagraphStrategy(ABC): + + @abstractmethod + def get_max_tokens(self, batch_size: int) -> int: + """Get max tokens.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/engine.py b/lmdeploy/pytorch/strategies/base/engine.py new file mode 100644 index 0000000000..1c8e0c4ef7 --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/engine.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + + +class EngineStrategy(ABC): + """Engine strategy.""" + + @abstractmethod + def get_prealloc_size(self, is_decoding: bool) -> int: + """Get prealloc_size.""" + pass + + @abstractmethod + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/model_agent.py b/lmdeploy/pytorch/strategies/base/model_agent.py new file mode 100644 index 0000000000..8701a256fa --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/model_agent.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import ABC, abstractmethod +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, List, Optional + +import numpy as np +import torch + +if TYPE_CHECKING: + from lmdeploy.pytorch.engine.logits_process import SamplingInputs + from lmdeploy.pytorch.messages import SchedulerSequence + from lmdeploy.pytorch.model_inputs import ModelInputs + SeqList = List[SchedulerSequence] + + +def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = v.to(device, non_blocking=non_blocking) + out_dict[k] = v + + return type(self)(**out_dict) + + +@dataclass +class ExtraInputs(ABC): + + def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + return to_device(self, device, non_blocking) + + +@dataclass +class ExtraOutputs(ABC): + + def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + return to_device(self, device, non_blocking) + + def to_cpu(self): + """To cpu.""" + return self.to_device('cpu', non_blocking=False) + + def to_numpy(self): + """To numpy.""" + out = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor) and v.dtype != torch.bfloat16: + v = v.detach().numpy() + elif hasattr(v, 'to_numpy'): + v = v.to_numpy() + out[k] = v + return type(self)(**out) + + def to_tensor(self): + """To tensor.""" + out = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + elif hasattr(v, 'to_tensor'): + v = v.to_tensor() + out[k] = v + return type(self)(**out) + + +@dataclass +class StoppingCriteria(ABC): + """Base class for stopping criteria.""" + + @abstractmethod + def step(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional['ModelInputs'] = None, + extra_inputs: Optional[ExtraInputs] = None): + """Check whether to stop generation.""" + pass + + def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + return to_device(self, device, non_blocking) + + +class ModelAgentStrategy(ABC): + """Base class for model agent strategies.""" + + @abstractmethod + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + pass + + @abstractmethod + def slice_extra_inputs(self, extra_inputs: ExtraInputs, seq_length: torch.LongTensor) -> ExtraInputs: + """Slice outputs.""" + pass + + @abstractmethod + def make_stopping_criteria(self, seqs: 'SeqList') -> StoppingCriteria: + """Create stopping criteria.""" + pass + + @abstractmethod + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + pass + + @abstractmethod + def make_extra_outputs(self, extra_inputs: ExtraInputs) -> ExtraOutputs: + """Create extra outputs.""" + pass + + @abstractmethod + def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ExtraInputs, + **kwargs): + """Step next inputs.""" + pass + + @abstractmethod + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: ExtraInputs): + """Post sampling.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py new file mode 100644 index 0000000000..f27134077d --- /dev/null +++ 
b/lmdeploy/pytorch/strategies/base/model_inputs.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch.model_inputs import ModelInputs + + +@record_function('make_dummy_input') +def make_dummy_inputs(batch_size: int, + max_q_seqlen: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1): + """Make dummy inputs global implement.""" + num_tokens = batch_size * max_q_seqlen + max_kv_seqlen = max_q_seqlen + input_ids = torch.randint(0, vocab_size, ( + 1, + num_tokens, + ), dtype=torch.long, device=device) + seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long, device=device) + history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device) + block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) + num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device) + local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device) + + return ModelInputs( + input_ids=input_ids, + seq_length=seq_length, + history_lengths=history_lengths, + block_offsets=block_offsets, + is_decoding=is_decoding, + num_ignored_history=num_ignored_history, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, + sum_kv_seqlen=batch_size, + local_adapter_ids=local_adapter_ids, + ) + + +class ModelInputsStrategy(ABC): + + @abstractmethod + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1) -> ModelInputs: + """Create dummy model inputs.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/sampling.py b/lmdeploy/pytorch/strategies/base/sampling.py new file mode 100644 index 0000000000..172454157b --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/sampling.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import List + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence + +SeqList = List[SchedulerSequence] + + +class SamplingStrategy(ABC): + """Base class for sampling strategies.""" + + @abstractmethod + def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: + """Create sampling inputs from the sequences.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/sequence.py b/lmdeploy/pytorch/strategies/base/sequence.py new file mode 100644 index 0000000000..408a3cc15e --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/sequence.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
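`make_dummy_inputs` above packs all dummy tokens into a single `[1, batch_size * max_q_seqlen]` tensor while keeping per-sequence metadata of length `batch_size`. A quick shape check, assuming the patched lmdeploy is importable:

```python
# Shape check for make_dummy_inputs (illustrative numbers; assumes this patch
# is applied so the module below exists).
from lmdeploy.pytorch.strategies.base.model_inputs import make_dummy_inputs

dummy = make_dummy_inputs(batch_size=2, max_q_seqlen=4, is_decoding=True)
print(tuple(dummy.input_ids.shape))      # expected: (1, 8)
print(dummy.seq_length.tolist())         # expected: [4, 4]
print(tuple(dummy.block_offsets.shape))  # expected: (2, 1)
```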
+from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional + +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest + +if TYPE_CHECKING: + from lmdeploy.pytorch.engine.model_agent import BatchedOutputs + from lmdeploy.pytorch.messages import SamplingParam, SchedulerSequence, SchedulerSession + SeqList = List[SchedulerSequence] + + +class SequenceStrategy(ABC): + + @abstractmethod + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequence': + """Make sequence.""" + pass + + @abstractmethod + def update_running(self, running: 'SeqList', batched_outputs: 'BatchedOutputs', is_decoding: bool) -> None: + """Update running sequences.""" + pass diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py new file mode 100644 index 0000000000..dc0395a017 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/__init__.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING + +from lmdeploy.pytorch.config import DLLMConfig, ModelConfig +from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy +from lmdeploy.utils import get_logger + +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy + from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy + from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy + from lmdeploy.pytorch.strategies.base.engine import EngineStrategy + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base import StrategyFactoryBase + +logger = get_logger('lmdeploy') + + +class DLLMStrategyFactory(StrategyFactoryBase): + + def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig): + """config.""" + self.model_config = model_config + self.dllm_config = dllm_config + + # update dllm_block_length + self.dllm_block_length = self._update_dllm_block_length() + + def _update_dllm_block_length(self): + """Update dllm_block_length.""" + if self.dllm_config.block_length is None: + dllm_block_length = self.model_config.dllm_block_length + if dllm_block_length is None: + dllm_block_length = 4 + logger.warning('Model does not provide dllm_block_length. 
' + f'Set dllm_block_length={dllm_block_length} as default.') + else: + dllm_block_length = self.dllm_config.block_length + + assert dllm_block_length is not None, 'dllm_block_length should be set in model_config or dllm_config' + + self.dllm_config.block_length = dllm_block_length + self.model_config.dllm_block_length = dllm_block_length + + if self.dllm_config.denoising_steps is None: + self.dllm_config.denoising_steps = dllm_block_length + return dllm_block_length + + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + from .cudagraph import DLLMCudagraphStrategy + return DLLMCudagraphStrategy(block_size=self.dllm_block_length) + + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + from .sampling import DLLMSamplingStrategy + pad_token_id = self.model_config.bos_token_id + pad_token_id = 0 if pad_token_id is None else pad_token_id + return DLLMSamplingStrategy(pad_token_id, self.dllm_block_length) + + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + from .model_inputs import DLLMModelInputsStrategy + return DLLMModelInputsStrategy(block_size=self.dllm_block_length) + + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + from .model_agent import DLLMModelAgentStrategy + return DLLMModelAgentStrategy(dllm_config=self.dllm_config, dllm_mask_token=self.model_config.dllm_mask_token) + + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + from .engine import DLLMEngineStrategy + return DLLMEngineStrategy(cache_config=cache_config, + scheduler_config=scheduler_config, + dllm_block_length=self.dllm_block_length) + + def build_sequence_strategy(self) -> SequenceStrategy: + from .sequence import DLLMSequenceStrategy + return DLLMSequenceStrategy(block_size=self.dllm_block_length, + dllm_mask_token=self.model_config.dllm_mask_token) diff --git a/lmdeploy/pytorch/strategies/dllm/cudagraph.py b/lmdeploy/pytorch/strategies/dllm/cudagraph.py new file mode 100644 index 0000000000..2e388b22de --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/cudagraph.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..base.cudagraph import CudagraphStrategy + + +class DLLMCudagraphStrategy(CudagraphStrategy): + + def __init__(self, block_size: int) -> None: + super().__init__() + self.block_size = block_size + + def get_max_tokens(self, batch_size: int) -> int: + """Get max tokens.""" + return batch_size * self.block_size diff --git a/lmdeploy/pytorch/strategies/dllm/engine.py b/lmdeploy/pytorch/strategies/dllm/engine.py new file mode 100644 index 0000000000..32244a7abc --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/engine.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
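`_update_dllm_block_length` in `DLLMStrategyFactory` above resolves the block length with a simple precedence: an explicit `dllm_config.block_length` wins, otherwise the model's own `dllm_block_length`, otherwise a default of 4 (and `denoising_steps` falls back to the block length). A plain-Python restatement of that precedence, with local stand-in names:

```python
# Plain-Python restatement of the resolution order (stand-in names, not the
# actual lmdeploy config objects).
def resolve_block_length(config_block_length, model_block_length, default=4):
    if config_block_length is not None:
        return config_block_length  # explicit engine/CLI setting wins
    if model_block_length is not None:
        return model_block_length   # fall back to the model's own value
    return default                  # last resort, as warned above


assert resolve_block_length(None, None) == 4
assert resolve_block_length(None, 8) == 8
assert resolve_block_length(16, 8) == 16
```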
+from functools import lru_cache + +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig +from lmdeploy.utils import get_logger + +from ..base.engine import EngineStrategy + +logger = get_logger('lmdeploy') + + +class DLLMEngineStrategy(EngineStrategy): + """DLLM Engine Strategy.""" + + def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, dllm_block_length: int) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + self.dllm_block_length = dllm_block_length + + self._check() + + def _check(self): + """check.""" + max_prefill_token_num = self.cache_config.max_prefill_token_num + max_batches = self.cache_config.max_batches + if self.dllm_block_length * max_batches > max_prefill_token_num: + logger.warning(f'dllm_block_length({self.dllm_block_length}) * max_batch_size ({max_batches}) ' + f'> max_prefill_token_num ({max_prefill_token_num}). ' + 'This may lead to OOM. Consider reducing max_batch_size or dllm_block_length.') + + @lru_cache(maxsize=2) + def get_prealloc_size(self, is_decoding: bool) -> int: + """Get prealloc_size.""" + if not is_decoding: + return 0 + block_size = self.cache_config.block_size + dllm_block_length = self.dllm_block_length + num_blocks = min(self.scheduler_config.prefill_interval // 2, block_size // dllm_block_length) + return num_blocks * dllm_block_length + + @lru_cache(maxsize=2) + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + if not is_decoding: + return 1 + block_size = self.cache_config.block_size + dllm_block_length = self.dllm_block_length + max_num_loops = block_size // dllm_block_length * 2 + num_loops = min(self.scheduler_config.prefill_interval, max_num_loops) + return num_loops diff --git a/lmdeploy/pytorch/strategies/dllm/model_agent.py b/lmdeploy/pytorch/strategies/dllm/model_agent.py new file mode 100644 index 0000000000..a5104d4981 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/model_agent.py @@ -0,0 +1,218 @@ +# Copyright (c) OpenMMLab. All rights reserved.
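For the decoding path, `DLLMEngineStrategy` above sizes preallocation and the number of engine loops from the KV-cache block size and the dllm block length. A worked example with assumed numbers (block_size=64, dllm_block_length=4, prefill_interval=16; these are illustrative, not defaults taken from the patch):

```python
# Worked example of get_prealloc_size / get_num_loops for is_decoding=True,
# using assumed values: block_size=64, dllm_block_length=4, prefill_interval=16.
block_size, dllm_block_length, prefill_interval = 64, 4, 16

prealloc = min(prefill_interval // 2, block_size // dllm_block_length) * dllm_block_length
num_loops = min(prefill_interval, block_size // dllm_block_length * 2)
print(prealloc, num_loops)  # 32 16
```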
+from dataclasses import dataclass +from typing import Any, List, Optional + +import numpy as np +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch import consts +from lmdeploy.pytorch.config import DLLMConfig +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria +from .unmasking import UnmaskingProcessor + +SeqList = List[SchedulerSequence] + + +@dataclass +class DLLMExtraInputs(ExtraInputs): + """DLLM extra inputs.""" + dllm_mask: torch.Tensor + + +@dataclass +class DLLMExtraOutputs(ExtraOutputs): + """DLLM extra outputs.""" + dllm_mask: torch.Tensor + + +def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor, + stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor, + output_start_pos: torch.Tensor, inputs: ModelInputs): + num_tokens = token_ids.size(0) + batch_size = num_appendable_ids.size(0) + block_size = num_tokens // batch_size + + # blocks might contain stop words in prev-round chat + # these stop words should be ignored + kv_seqlens = inputs.history_lengths + inputs.seq_length + ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0) + ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device) + ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten() + token_ids = token_ids.clone() + token_ids[ignore_mask] = -1 + + # find stop words + sw_stopped = (token_ids[:, None] == stop_words).any(1) + sw_stopped = sw_stopped.view(batch_size, block_size) + sw_stop_pos = sw_stopped.int().argmax(1) + + stop_pos = torch.where(stopped, stop_pos, sw_stop_pos) + sw_stopped = sw_stopped.any(dim=1) + sw_stopped = sw_stopped & is_unmasked + stopped = stopped | sw_stopped + + # update num_appendable_ids + one_ids = torch.clamp_max(num_appendable_ids, 0) + num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + + return stopped, stop_pos, num_appendable_ids + + +@dataclass +class DLLMStoppingCriteria(StoppingCriteria): + num_appendable_ids: torch.Tensor + output_start_pos: torch.Tensor + + @record_function('stopping_criteria') + def step(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional[ModelInputs] = None, + extra_inputs: Optional[DLLMExtraInputs] = None): + """Check whether to stop generation.""" + num_appendable_ids = self.num_appendable_ids + output_start_pos = self.output_start_pos + num_tokens = token_ids.size(0) + batch_size = num_appendable_ids.size(0) + block_size = num_tokens // batch_size + + dllm_mask = extra_inputs.dllm_mask + dllm_mask = dllm_mask.view(batch_size, block_size) + is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1) + + # check stop by num_new_tokens + num_appendable_ids -= is_unmasked * block_size + stopped = num_appendable_ids <= 0 + stop_pos = block_size - 1 + num_appendable_ids + + # check stop words + if stop_words is not None: + stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids, + stop_words, + is_unmasked, + stopped, + stop_pos, + num_appendable_ids, + output_start_pos=output_start_pos, + inputs=inputs) + + new_stopping = DLLMStoppingCriteria(num_appendable_ids=num_appendable_ids, output_start_pos=output_start_pos) + return stopped, stop_pos, new_stopping + + +class 
DLLMModelAgentStrategy(ModelAgentStrategy): + + def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int): + block_size = dllm_config.block_length + self.block_size = block_size + self.dllm_mask_token = dllm_mask_token + + self.unmasking_processor = UnmaskingProcessor(dllm_config=dllm_config) + + def _update_dllm(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor): + """Update token_ids and dllm_mask.""" + dllm_mask_token = self.dllm_mask_token + dllm_block_length = self.block_size + + # reshape to (batch, dllm_block_length) + next_token_ids = next_token_ids.view(-1, dllm_block_length).clone() + dllm_mask = dllm_mask.view(-1, dllm_block_length).clone() + + # flags + is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1) + + is_masked = (dllm_mask == consts.DLLM_MASKED) + next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token + dllm_mask[is_cached] = consts.DLLM_MASKED + seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, ))) + + return next_token_ids.flatten(), dllm_mask.flatten(), seqlens + + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + block_length = self.block_size + # batch size = 1 + if len(seq_length) == 1: + return inputs[-block_length:] + + if len(seq_length) * block_length == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(0) + block_range = torch.arange(-block_length, 0, device=last_idx.device) + index = (last_idx[:, None] + block_range[None, :]).flatten() + inputs = inputs[index] + return inputs + + def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, seq_length: torch.LongTensor) -> DLLMExtraInputs: + """Slice outputs.""" + dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, seq_length) + return DLLMExtraInputs(dllm_mask=dllm_mask) + + def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, + dllm_mask: torch.Tensor, **kwargs): + """step.""" + from lmdeploy.pytorch import consts + dllm_block_size = self.block_size + DLLM_UNMASKED = consts.DLLM_UNMASKED + is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True) + num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size) + num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos) + sampling_inputs.num_ignore_eos = num_ignore_eos.flatten() + return sampling_inputs + + def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria: + """Create stopping criteria.""" + # num_appendable + num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] + num_appendable = torch.tensor(num_appendable) + block_size = self.block_size + remain = [seq.num_valid_ids % block_size for seq in seqs] + num_appendable += torch.tensor(remain) + + # output_start_pos + pos = [seq.output_start_pos for seq in seqs] + output_start_pos = torch.tensor(pos) + + return DLLMStoppingCriteria(num_appendable_ids=num_appendable, output_start_pos=output_start_pos) + + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + dllm_masks = [seq.dllm_mask for seq in seqs] + dllm_masks = torch.as_tensor(np.concatenate(dllm_masks)) + return DLLMExtraInputs(dllm_mask=dllm_masks) + + def make_extra_outputs(self, extra_inputs: DLLMExtraInputs) -> DLLMExtraOutputs: + """Create extra outputs.""" + dllm_mask = extra_inputs.dllm_mask + return DLLMExtraOutputs(dllm_mask=dllm_mask) + + def 
update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: DLLMExtraInputs, + **kwargs): + """Step next inputs.""" + model_inputs.model_metas = model_metas + dllm_mask = extra_inputs.dllm_mask + + next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length) + model_inputs.step(next_token_ids, step_seqlens) + self._step_sampling_inputs(sampling_inputs, next_token_ids, dllm_mask=dllm_mask) + + extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask) + return model_inputs, extra_inputs + + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: DLLMExtraInputs): + """Post sampling.""" + dllm_mask = extra_inputs.dllm_mask + input_ids = inputs.input_ids + input_ids = self.slice_outputs(input_ids.flatten(), inputs.seq_length) + + dllm_mask, next_token_ids = self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask) + + extra_inputs.dllm_mask = dllm_mask + return next_token_ids, extra_inputs diff --git a/lmdeploy/pytorch/strategies/dllm/model_inputs.py b/lmdeploy/pytorch/strategies/dllm/model_inputs.py new file mode 100644 index 0000000000..f05ba415f2 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/model_inputs.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs + + +class DLLMModelInputsStrategy(ModelInputsStrategy): + + def __init__(self, block_size: int): + self.block_size = block_size + + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1) -> ModelInputs: + """Create dummy model inputs.""" + return make_dummy_inputs(batch_size, + max_q_seqlen=self.block_size, + is_decoding=is_decoding, + device=device, + dummy_block_id=dummy_block_id, + vocab_size=vocab_size) diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py new file mode 100644 index 0000000000..2ad5d5ecd7 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
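`DLLMModelAgentStrategy.slice_outputs` above selects the trailing dllm block of every packed sequence rather than a single last token. A small illustration with made-up values (block length 2):

```python
# Illustration only: gather the last block (block_length=2) of each packed
# sequence, mirroring the index math in DLLMModelAgentStrategy.slice_outputs.
import torch

block_length = 2
outputs = torch.arange(8)                     # packed tokens of two sequences
seq_length = torch.tensor([6, 2])             # lengths are multiples of block_length
last_idx = seq_length.cumsum(0)               # tensor([6, 8])
block_range = torch.arange(-block_length, 0)  # tensor([-2, -1])
index = (last_idx[:, None] + block_range[None, :]).flatten()
print(outputs[index])                         # tensor([4, 5, 6, 7])
```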
+from typing import List + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence + +from ..ar.sampling import ARSamplingStrategy + +SeqList = List[SchedulerSequence] + + +class DLLMSamplingStrategy(ARSamplingStrategy): + """Sampling strategy for diffusion LLM (dllm) models.""" + + def __init__(self, pad_token_id: int, dllm_block_length: int) -> None: + super().__init__(pad_token_id) + self.dllm_block_length = dllm_block_length + + def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: + """Create sampling inputs from the sequences.""" + out = super().make_sampling_inputs(seqs) + dllm_block_length = self.dllm_block_length + + # repeat tensor + update_attr_names = [ + 'temperature', + 'bad_words', + 'bad_mask', + 'stop_words', + 'stop_mask', + 'repetition_penalty', + 'top_k', + 'top_p', + 'min_p', + 'random_seeds', + 'random_offsets', + 'all_ids', + 'guided_input_ids', + 'num_ignore_eos', + ] + for name in update_attr_names: + attr = getattr(out, name) + if attr is None: + continue + repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) + attr = attr[None].repeat(*repeats).flatten(0, 1) + setattr(out, name, attr) + + if len(out.response_formats) > 0: + new_resp_formats = [] + for resp in out.response_formats: + new_resp_formats += [resp] * dllm_block_length + out.response_formats = tuple(new_resp_formats) + + out.batch_size *= dllm_block_length + + return out diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py new file mode 100644 index 0000000000..ab004a2b63 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import numpy as np +from torch import Tensor + +from lmdeploy.pytorch import consts +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.engine.model_agent import BatchedOutputs +from lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, + SchedulerSession, UpdateTokenMode, _to_ndarray) + +from ..ar.sequence import SchedulerSequenceDefault +from ..base.sequence import SequenceStrategy + +SeqList = List['SchedulerSequenceDLLM'] + +DLLM_MASKED = consts.DLLM_MASKED +DLLM_UNMASKED = consts.DLLM_UNMASKED +DLLM_CACHED = consts.DLLM_CACHED +DLLM_MASK_DTYPE = np.uint8 + + +class HistoryDLLMMask(HistoryTokenIds): + + def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = DLLM_MASK_DTYPE): + super().__init__(token_ids=token_ids, dtype=dtype) + + +@dataclass +class SchedulerSequenceDLLM(SchedulerSequenceDefault): + + # For dllm + history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask) + + def __post_init__(self): + """Post init.""" + super().__post_init__() + self._num_valid_ids: int = len(self.history_cache) + self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy + + @property + def dllm_mask(self): + start = self.num_history_ids + end = start + self._num_token_ids + return self.history_dllm_mask._token_ids[start:end] + + @property + def num_valid_ids(self): + return self._num_valid_ids + + @property + def generated_ids(self) -> np.ndarray: + end = self.num_valid_ids + start = end - self.num_new_tokens + return self.history_cache._token_ids[start:end] + + @property + def all_dllm_mask(self): + return 
self.history_dllm_mask._token_ids[:self.num_all_ids] + + @property + def dllm_block_length(self): + return self._strategy.block_size + + @property + def dllm_mask_token(self): + return self._strategy.dllm_mask_token + + def set_stop_pos(self, pos: int): + dllm_block_length = self.dllm_block_length + val = dllm_block_length - pos - 1 + self._num_valid_ids -= val + self.num_new_tokens -= val + + def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Append tokens.""" + num_tokens = len(token_ids) + dllm_block_length = self.dllm_block_length + dllm_mask_token = self.dllm_mask_token + new_token_ids = [token_ids] + new_dllm_mask = [dllm_mask] + + # add uncached tokens in token_ids + # for example, in [cccc cccc uumm], the [uu] in the last block remains valid. + num_remain_valid = self.num_valid_ids - self.num_history_ids + if num_remain_valid != 0: + prev_token_ids = self.valid_ids[-num_remain_valid:] + prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) + new_token_ids = [prev_token_ids] + new_token_ids + new_dllm_mask = [prev_dllm_mask] + new_dllm_mask + self.history_cache.resize(self.num_history_ids) + self.history_dllm_mask.resize(self.num_history_ids) + num_tokens += num_remain_valid + + # pad to align with dllm_block_length + num_pad = (-num_tokens) % dllm_block_length + if num_pad > 0: + pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, )) + pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, )) + new_token_ids += [pad_ids] + new_dllm_mask += [pad_mask] + + token_ids = np.concatenate(new_token_ids) + dllm_mask = np.concatenate(new_dllm_mask) + + assert len(token_ids) % dllm_block_length == 0 + + self.history_cache.append(token_ids) + self.history_dllm_mask.append(dllm_mask) + self.output_start_pos = self._num_valid_ids + len(token_ids) + self._num_valid_ids = self.num_history_ids + num_tokens + self._num_token_ids = len(token_ids) + self.num_new_tokens = 0 + + def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Update token ids for decode.""" + num_tokens = len(token_ids) + dllm_block_length = self.dllm_block_length + dllm_mask_token = self.dllm_mask_token + assert num_tokens % dllm_block_length == 0 + num_history_ids = self.num_history_ids + + token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token + self.history_cache[num_history_ids:] = token_ids + self.history_dllm_mask[num_history_ids:] = dllm_mask + + # check if all blocks are cached + last_mask = dllm_mask[-dllm_block_length:] + is_unmasked = np.all(last_mask == DLLM_UNMASKED) + is_cached = np.all(last_mask == DLLM_CACHED) + + if is_unmasked: + num_new = dllm_block_length - self._num_valid_ids % dllm_block_length + self._num_valid_ids += num_new + self.num_new_tokens += num_new + + if is_cached: + # add new block + new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, )) + new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, )) + self.history_cache.append(new_token_ids) + self.history_dllm_mask.append(new_dllm_mask) + self._num_history_ids += self._num_token_ids + self._num_token_ids = dllm_block_length + + def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Update token ids for prefill.""" + dllm_block_length = self.dllm_block_length + num_history_ids = self.num_history_ids + + # fill input cache + if self.num_token_ids > dllm_block_length: + end = self.num_token_ids - dllm_block_length + 
self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED + self._num_history_ids += end + self._num_token_ids -= end + + # decoding update + self._update_token_ids_decode(token_ids, dllm_mask) + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + dllm_mask: Tensor = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(multimodals) + + self.arrive_time = time.perf_counter() + + token_ids: np.ndarray = _to_ndarray(token_ids) + if dllm_mask is None: + dllm_mask = np.full_like(token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) + dllm_mask: np.ndarray = _to_ndarray(dllm_mask) + + if mode == UpdateTokenMode.INPUTS: + self._update_token_ids_inputs(token_ids, dllm_mask) + elif mode == UpdateTokenMode.PREFILL: + self._update_token_ids_prefill(token_ids, dllm_mask) + else: + self._update_token_ids_decode(token_ids, dllm_mask) + + if model_meta is not None: + self.model_meta = model_meta + + +class DLLMSequenceStrategy(SequenceStrategy): + + def __init__(self, block_size: int, dllm_mask_token: int) -> None: + self.block_size = block_size + self.dllm_mask_token = dllm_mask_token + + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequenceDLLM': + """Make sequence.""" + return SchedulerSequenceDLLM(seq_id=seq_id, + session=session, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) + + def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None: + """Update running sequences.""" + next_token_ids = batched_outputs.next_token_ids + stopped = batched_outputs.stopped + stopped = stopped.tolist() + model_metas = batched_outputs.model_metas + if model_metas is None: + model_metas = [None] * len(running) + dllm_mask = batched_outputs.extra_outputs.dllm_mask + stop_pos = batched_outputs.stop_pos + + batch_size = len(running) + next_token_ids = next_token_ids.view(batch_size, -1).numpy() + dllm_mask = dllm_mask.view(batch_size, -1).numpy() + stop_pos = stop_pos.tolist() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL + for idx, token in enumerate(next_token_ids): + msg = running[idx] + stop = stopped[idx] + model_meta = model_metas[idx] + mask = dllm_mask[idx] + if msg.status != MessageStatus.LOCKED: + continue + + # fill token + msg.update_token_ids(token, dllm_mask=mask, model_meta=model_meta, mode=update_mode) + if stop: + msg.set_stop_pos(stop_pos[idx]) + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED diff --git a/lmdeploy/pytorch/strategies/dllm/unmasking.py b/lmdeploy/pytorch/strategies/dllm/unmasking.py new file mode 100644 index 0000000000..7c24ac8d3f --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/unmasking.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch
+from torch.profiler import record_function
+
+from lmdeploy.pytorch import consts
+from lmdeploy.pytorch.config import DLLMConfig, UnmaskingStrategy
+
+DLLM_MASKED = consts.DLLM_MASKED
+DLLM_UNMASKED = consts.DLLM_UNMASKED
+DLLM_CACHED = consts.DLLM_CACHED
+
+
+class UnmaskingProcessor:
+
+    def __init__(self, dllm_config: DLLMConfig):
+        self.dllm_config = dllm_config
+
+    def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor):
+        """Gather per-token confidence scores from logits."""
+        scores = logits.softmax(dim=-1)
+        scores = scores.gather(-1, token_ids.unsqueeze(-1)).flatten()
+        return scores
+
+    def _get_denoise_num(self):
+        """Number of tokens to unmask per denoising step."""
+        block_size = self.dllm_config.block_length
+        denoising_steps = self.dllm_config.denoising_steps
+        if denoising_steps is None:
+            denoising_steps = block_size
+        num = block_size // denoising_steps
+        num = max(1, min(num, block_size))
+        return num
+
+    def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
+        """Static strategy: unmask the top-k most confident masked tokens in each block."""
+        block_size = self.dllm_config.block_length
+        topk = self._get_denoise_num()
+        scores = self._get_scores(logits, token_ids)
+        is_masked = dllm_mask == DLLM_MASKED
+        scores = torch.where(is_masked, scores, scores.new_zeros((1, )))
+
+        scores = scores.view(-1, block_size)
+        dllm_mask = dllm_mask.view(-1, block_size)
+        _, indices = scores.topk(topk, dim=-1)
+        dllm_unmasked = dllm_mask.scatter(-1, indices, DLLM_UNMASKED)
+
+        is_masked = is_masked.view_as(dllm_mask)
+        dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)
+        return dllm_mask.flatten()
+
+    def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
+        """Dynamic strategy: unmask tokens whose confidence reaches the threshold, at least one per block."""
+        block_size = self.dllm_config.block_length
+        threshold = self.dllm_config.confidence_threshold
+        scores = self._get_scores(logits, token_ids)
+        is_masked = dllm_mask == DLLM_MASKED
+        scores = torch.where(is_masked, scores, scores.new_zeros((1, )))
+
+        scores = scores.view(-1, block_size)
+        dllm_mask = dllm_mask.view(-1, block_size)
+        _, indices = scores.topk(1, dim=-1)
+        scores = scores.scatter(-1, indices, threshold)
+
+        is_masked = is_masked.view_as(dllm_mask)
+        is_masked &= scores >= threshold
+        dllm_mask[is_masked] = DLLM_UNMASKED
+        return dllm_mask.flatten()
+
+    def sequential(self, dllm_mask: torch.Tensor):
+        """Sequential strategy: unmask the next `denoise_num` masked positions of each block in order."""
+        block_size = self.dllm_config.block_length
+        denoise_num = self._get_denoise_num()
+        dllm_mask = dllm_mask.view(-1, block_size)
+        is_masked = dllm_mask == DLLM_MASKED
+
+        # indices of the first masked position in each block, extended by denoise_num
+        indices = is_masked.int().argmax(dim=1)
+        ranges = torch.arange(0, denoise_num, device=indices.device, dtype=indices.dtype)
+        indices = indices[:, None] + ranges[None, :]
+        indices = indices % block_size
+
+        dllm_unmasked = dllm_mask.clone()
+        dllm_unmasked = dllm_unmasked.scatter(-1, indices, DLLM_UNMASKED)
+        dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask)
+
+        return dllm_mask.flatten()
+
+    @record_function('unmasking')
+    def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor):
+        """Update the dLLM mask and token ids according to the configured unmasking strategy."""
+        strategy = self.dllm_config.unmasking_strategy
+        if strategy is None:
+            return dllm_mask, token_ids
+
+        # reshape to [num_blocks, block_size]
+        block_size = self.dllm_config.block_length
+        dllm_mask = dllm_mask.unflatten(0, (-1, block_size))
+
+        is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1)
+        first_mask = dllm_mask[:, 0]
+
+        # mark fully-unmasked blocks as cached
+        is_block_unmasked = is_same & (first_mask == DLLM_UNMASKED)
+        
dllm_mask[is_block_unmasked] = DLLM_CACHED + + dllm_mask = dllm_mask.flatten() + token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids) + if strategy == UnmaskingStrategy.LOW_CONFIDENCE_STATIC: + dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask) + elif strategy == UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC: + dllm_mask = self.low_confidence_dynamic(logits, token_ids, dllm_mask) + elif strategy == UnmaskingStrategy.SEQUENTIAL: + dllm_mask = self.sequential(dllm_mask) + else: + raise RuntimeError(f'strategy {strategy} not supported.') + + return dllm_mask, token_ids diff --git a/lmdeploy/pytorch/tools/utils.py b/lmdeploy/pytorch/tools/utils.py index 32b707b140..98d42772c0 100644 --- a/lmdeploy/pytorch/tools/utils.py +++ b/lmdeploy/pytorch/tools/utils.py @@ -132,6 +132,13 @@ def visualize_pipe_out(outputs, enable_meta: bool = True): from lmdeploy.messages import Response + try: + from termcolor import colored + except ImportError: + + def colored(text, color=None, on_color=None, attrs=None): + return text + if isinstance(outputs, Response): outputs = [outputs] try: @@ -139,24 +146,49 @@ def visualize_pipe_out(outputs, enable_meta: bool = True): except Exception: term_size = 100 - def _lined_print(msg: str, line_format: str = '-', full_line: bool = False): - print(msg) - if full_line: - columns = term_size - else: - columns = max(len(m) for m in msg.split('\n')) - print(line_format * columns) + border_color = 'cyan' + meta_color = 'light_grey' + number_color = 'green' + + def _print_title(title: str, color: str = border_color): + title_text = f' {title} ' + print(colored(f'【{title_text}】', color, attrs=['bold'])) + + def _print_section(title: str, content: str, color: str = border_color): + """Simple title and content printing.""" + _print_title(title, color) + print(content) + + def _print_meta(out: Response): + """Enhanced meta information display.""" + # Create a clean table-like format + finish_color = 'yellow' if out.finish_reason == 'stop' else 'red' + meta_content = [ + f"{colored('• Input Tokens:', meta_color)} {colored(out.input_token_len, number_color)}", + f"{colored('• Generated Tokens:', meta_color)} {colored(out.generate_token_len, number_color)}", + f"{colored('• Finish Reason:', meta_color)} {colored(out.finish_reason, finish_color)}" + ] + + lines = '\n'.join(meta_content) + lines += '\n' + _print_section('METADATA', lines, border_color) + + # Main loop + print(colored('━' * term_size, border_color)) outputs: List[Response] = outputs - term_line = '—' * term_size - print(term_line) for idx, out in enumerate(outputs): - _lined_print(f'output[{idx}]', '=') + header = f'OUTPUT [{idx + 1}/{len(outputs)}]' + header_formatted = colored(f'✦ {header}', 'light_magenta', attrs=['bold']) + print(header_formatted) + print() + if enable_meta: - _lined_print('meta', '-') - _lined_print( - f'input_token_len={out.input_token_len}\n' - f'generate_token_len={out.generate_token_len}\n' - f'finish_reason="{out.finish_reason}"', '—') - _lined_print('text', '-') - _lined_print(f'{out.text}', '—', full_line=True) + _print_meta(out) + + _print_section('TEXT', out.text, border_color) + + if idx < len(outputs) - 1: # Add separator when it's not the last output + print(colored('─' * (term_size), border_color, attrs=['dark'])) + else: + print(colored('━' * term_size, border_color)) diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py index 2b6d17aa2f..5e1ad7cf1a 100644 --- a/tests/pytorch/kernel/test_flash_attention.py 
+++ b/tests/pytorch/kernel/test_flash_attention.py @@ -11,16 +11,18 @@ def _conti_input(data, q_seqlens): def _make_bias(q_seqlens, history_lens, neg_val, causal): + batch_size = q_seqlens.shape[0] kv_seqlens = q_seqlens + history_lens max_seq_len = q_seqlens.max().item() max_kv_len = kv_seqlens.max().item() if causal: - seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] - for r, l in zip(seq_ranges, q_seqlens): - r[l:] = -max_kv_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges.repeat(batch_size, 1) + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]) return mask.float() * neg_val else: @@ -31,6 +33,27 @@ def _make_bias(q_seqlens, history_lens, neg_val, causal): return (~mask).float() * neg_val +def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float, + block_sparse_size: int): + """Make block sparse bias.""" + batch_size = q_seqlens.shape[0] + kv_seqlens = q_seqlens + history_lens + max_seq_len = q_seqlens.max().item() + max_kv_len = kv_seqlens.max().item() + + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges // block_sparse_size * block_sparse_size + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges // block_sparse_size * block_sparse_size + kv_ranges = kv_ranges.repeat(batch_size, 1) + + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]) + return mask.float() * neg_val + + def _naive_attention(batched_q, batched_kv, bias, sinks=None): batched_k, batched_v = batched_kv @@ -283,3 +306,42 @@ def test_sinks(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv sinks=sinks, causal=causal) torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5) + + # block sparse attention + @pytest.fixture + def block_sparse_size(self): + yield 4 + + @pytest.fixture + def block_sparse_mask(self, q_seqlens, history_lens, block_sparse_size): + neg_val = -1e30 + yield _make_block_sparse_bias(q_seqlens, history_lens, neg_val, block_sparse_size) + + @pytest.fixture + def block_sparse_gt(self, batched_q, batched_kv, block_sparse_mask): + yield _naive_attention(batched_q, batched_kv, block_sparse_mask) + + @pytest.mark.parametrize('head_dim_k', [32], indirect=True) + @pytest.mark.parametrize('head_dim_v', [32], indirect=True) + @pytest.mark.parametrize('num_heads_q', [8], indirect=True) + @pytest.mark.parametrize('num_heads_k', [2], indirect=True) + @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([16, 32], [64, 8])], indirect=True) + def test_block_sparse_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv_seqlens, + head_dim_v, block_sparse_size, block_sparse_gt): + from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attention_fwd + max_seq_len = q_seqlens.max().item() + + conti_k, conti_v = conti_kv + out = conti_q.new_empty(*conti_q.shape[:-1], head_dim_v) + flash_attention_fwd(conti_q, + conti_k, + conti_v, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, 
+ kv_seqlens=kv_seqlens, + max_seqlen=max_seq_len, + block_sparse_size=block_sparse_size) + gt = _conti_input(block_sparse_gt, q_seqlens) + torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-5) diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index a6289fc626..f4268cacdc 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -10,20 +10,42 @@ def _conti_input(data, seq_lens): return data -def _make_bias(seq_lens, history_lens, neg_val): - full_seq_lens = seq_lens + history_lens - max_seq_len = seq_lens.max().item() - max_full_len = full_seq_lens.max().item() - seq_ranges = [torch.arange(max_seq_len) for _ in seq_lens] - for r, l in zip(seq_ranges, seq_lens): - r[l:] = -max_full_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() +def _make_bias(q_seqlens, history_lens, neg_val): + batch_size = q_seqlens.shape[0] + full_seq_lens = q_seqlens + history_lens + max_seq_len = q_seqlens.max().item() + max_kv_len = full_seq_lens.max().item() + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges.repeat(batch_size, 1) mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None] return mask.float() * neg_val +def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float, + block_sparse_size: int): + """Make block sparse bias.""" + batch_size = q_seqlens.shape[0] + kv_seqlens = q_seqlens + history_lens + max_seq_len = q_seqlens.max().item() + max_kv_len = kv_seqlens.max().item() + + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges // block_sparse_size * block_sparse_size + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges // block_sparse_size * block_sparse_size + kv_ranges = kv_ranges.repeat(batch_size, 1) + + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]) + return mask.float() * neg_val + + def _make_blocked_cache(batched_k, batched_v, seq_lens, @@ -119,7 +141,7 @@ def _make_cu_seqlens(seqlens): return output -class TestPagedAttention: +class TestPagedAttentionBase: @pytest.fixture def dtype(self): @@ -154,18 +176,22 @@ def history_lens(self, request): yield torch.tensor(request.param, device='cuda') @pytest.fixture - def seq_lens(self, history_lens): - yield torch.ones_like(history_lens) + def seq_len(self): + yield 1 + + @pytest.fixture + def seq_lens(self, seq_len, history_lens): + yield torch.ones_like(history_lens) * seq_len @pytest.fixture - def kv_seqlens(self, history_lens): - yield 1 + history_lens + def kv_seqlens(self, seq_lens, history_lens): + yield seq_lens + history_lens @pytest.fixture - def batched_q(self, kv_seqlens, num_heads_q, feat_dim, dtype): + def batched_q(self, seq_len, kv_seqlens, num_heads_q, feat_dim, dtype): torch.manual_seed(123) batch_size = len(kv_seqlens) - inputs = torch.rand(batch_size, 1, num_heads_q, feat_dim, dtype=dtype, device='cuda') + inputs = torch.rand(batch_size, seq_len, num_heads_q, feat_dim, dtype=dtype, device='cuda') yield inputs @pytest.fixture @@ -178,8 
+204,7 @@ def batched_kv(self, kv_seqlens, num_heads_k, feat_dim, feat_dim_v, dtype): yield k, v @pytest.fixture - def conti_q(self, kv_seqlens, batched_q): - seq_lens = torch.ones_like(kv_seqlens) + def conti_q(self, seq_lens, batched_q): yield _conti_input(batched_q, seq_lens) @pytest.fixture @@ -225,7 +250,10 @@ def gt(self, batched_q, batched_kv, mask): def conti_gt(self, gt, seq_lens): yield _conti_input(gt, seq_lens) - @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True) + +class TestPagedAttention(TestPagedAttentionBase): + + @pytest.mark.parametrize('feat_dim', [32, 32], indirect=True) @pytest.mark.parametrize('feat_dim_v', [32], indirect=True) @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True) @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True) @@ -288,7 +316,7 @@ def test_window_attention(self, conti_q, blocked_kv, block_offsets, history_lens torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5) -class TestPagedAttentionSink(TestPagedAttention): +class TestPagedAttentionSink(TestPagedAttentionBase): @pytest.fixture def sinks(self, num_heads_q, dtype): @@ -308,7 +336,8 @@ def conti_sink_gt(self, sink_gt, seq_lens): @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True) @pytest.mark.parametrize('block_size', [16], indirect=True) @pytest.mark.parametrize('layout', ['bshd'], indirect=True) - def test_sink(self, conti_q, blocked_kv, block_offsets, history_lens, feat_dim_v, layout, sinks, conti_sink_gt): + def test_paged_attention(self, conti_q, blocked_kv, block_offsets, history_lens, feat_dim_v, layout, sinks, + conti_sink_gt): from lmdeploy.pytorch.kernels.cuda import paged_attention_fwd kv_seq_lens = 1 + history_lens @@ -450,3 +479,46 @@ class TestPagedAttentionInt4(TestPagedAttentionInt8): @pytest.fixture def nbits(self): yield 4 + + +class TestPagedAttentionBlockDecoding(TestPagedAttentionBase): + + @pytest.fixture + def seq_len(self): + yield 4 + + @pytest.fixture + def mask(self, seq_lens, history_lens, seq_len): + neg_val = -1e30 + yield _make_block_sparse_bias(seq_lens, history_lens, neg_val, seq_len) + + @pytest.fixture + def gt(self, batched_q, batched_kv, mask): + yield _naive_attention(batched_q, batched_kv, mask) + + @pytest.fixture + def conti_gt(self, gt, seq_lens): + yield _conti_input(gt, seq_lens) + + @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True) + @pytest.mark.parametrize('feat_dim_v', [32], indirect=True) + @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True) + @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True) + @pytest.mark.parametrize('block_size', [16], indirect=True) + @pytest.mark.parametrize('layout', ['bshd'], indirect=True) + def test_paged_attention(self, conti_q, blocked_kv, block_offsets, seq_lens, history_lens, feat_dim_v, layout, + conti_gt): + from lmdeploy.pytorch.kernels.cuda import paged_attention_fwd + kv_seq_lens = seq_lens + history_lens + + blocked_k, blocked_v = blocked_kv + out = conti_q.new_empty(*conti_q.shape[:-1], feat_dim_v) + + paged_attention_fwd(conti_q, + blocked_k, + blocked_v, + out, + block_offsets=block_offsets, + kv_seqlens=kv_seq_lens, + kv_layout=layout) + torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index cdf45d5d6d..f74b6548cf 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ 
b/tests/pytorch/paging/test_block_manager.py @@ -2,7 +2,7 @@ import pytest import torch -from lmdeploy.pytorch.messages import SchedulerSession +from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta from lmdeploy.pytorch.paging.block_manager import DefaultBlockManager, WindowBlockManager from lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator @@ -89,8 +89,16 @@ def num_gpu_blocks(self): def block_mgr(self, num_cpu_blocks, num_gpu_blocks): yield DefaultBlockManager(num_cpu_blocks, num_gpu_blocks) - def test_alloc(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + @pytest.fixture + def seq_manager(self, block_size): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + seq_meta = SequenceMeta(block_size, strategy=strategy) + yield SequenceManager(seq_meta) + + def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # test alloc token_ids = torch.tensor([1]) @@ -113,9 +121,10 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_num_required_blocks(self, block_mgr, block_size, num_gpu_blocks): + def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks): from lmdeploy.pytorch.messages import InputEmbeddings - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = torch.tensor([1]) msg = sess.add_sequence(token_ids) @@ -133,8 +142,9 @@ def test_num_required_blocks(self, block_mgr, block_size, num_gpu_blocks): num_required = block_mgr.num_required_blocks(msg) assert num_required == 3 - def test_append_slot(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # test append token_ids = torch.tensor([1]) @@ -158,8 +168,9 @@ def test_append_slot(self, block_mgr, block_size, num_gpu_blocks): assert len(block_table) == 2 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 - def test_swap(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + def test_swap(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = torch.tensor([1] * (block_size + 1)) msg = sess.add_sequence(token_ids) @@ -215,12 +226,20 @@ def num_cpu_blocks(self): def num_gpu_blocks(self): yield 4 + @pytest.fixture + def seq_manager(self, block_size): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + seq_meta = SequenceMeta(block_size, strategy=strategy) + yield SequenceManager(seq_meta) + @pytest.fixture def block_mgr(self, num_cpu_blocks, num_gpu_blocks, window_size): yield WindowBlockManager(num_cpu_blocks, num_gpu_blocks, window_size) - def test_alloc(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # test alloc token_ids = torch.tensor([1]) @@ -243,8 +262,8 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks): msg = sess.add_sequence(token_ids) 
assert not block_mgr.can_allocate(msg) - def test_win_alloc(self, block_mgr, block_size, num_gpu_blocks, window_size): - sess = SchedulerSession(0, block_size) + def test_win_alloc(self, block_mgr, seq_manager, num_gpu_blocks, window_size): + sess = SchedulerSession(0, seq_manager) # 2 win block token_ids = torch.tensor([1] * window_size) diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 06829f4c75..7d20c96dab 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -2,7 +2,7 @@ import pytest from lmdeploy.pytorch.config import CacheConfig -from lmdeploy.pytorch.messages import SchedulerSession +from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta from lmdeploy.pytorch.paging.block_manager import build_block_manager from lmdeploy.pytorch.paging.block_trie import BlockTrie @@ -37,9 +37,17 @@ def block_mgr(self, cache_config): def block_trie(self, cache_config, block_mgr): yield BlockTrie(cache_config, block_mgr) - def test_allocate(self, block_trie, block_mgr, block_size): + @pytest.fixture + def seq_manager(self, block_size): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + seq_meta = SequenceMeta(block_size, strategy=strategy) + yield SequenceManager(seq_meta) + + def test_allocate(self, block_trie, block_mgr, seq_manager): allocator = block_trie.allocator - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = ([1] * block_size + [2] * block_size) token_ids += [3] * (block_size // 2) seq = sess.add_sequence(token_ids) @@ -75,9 +83,10 @@ def test_allocate(self, block_trie, block_mgr, block_size): assert node in block_trie.leaves assert len(block_trie.leaves) == 1 - def test_match(self, block_trie, block_mgr, block_size): + def test_match(self, block_trie, block_mgr, seq_manager): allocator = block_trie.allocator - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # initialize cache token_ids = ([1] * block_size + [2] * block_size) @@ -112,9 +121,10 @@ def test_match(self, block_trie, block_mgr, block_size): ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [4, 3]) - def test_evict(self, block_trie, block_size, num_gpu_blocks): + def test_evict(self, block_trie, seq_manager, num_gpu_blocks): block_mgr = block_trie.block_manager - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = ([1] * block_size * (num_gpu_blocks - 1)) token_ids += [2] * (block_size // 2) seq = sess.add_sequence(token_ids) diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index f14ab8249e..a0acf5f054 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -2,7 +2,7 @@ import torch from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig -from lmdeploy.pytorch.messages import MessageStatus +from lmdeploy.pytorch.messages import MessageStatus, SequenceMeta from lmdeploy.pytorch.paging.scheduler import Scheduler @@ -32,8 +32,14 @@ def scheduler_config(self): yield SchedulerConfig(max_batches=4, max_session_len=128, max_request_output_len=64, eviction_type='recompute') @pytest.fixture - def scheduler(self, cache_config, scheduler_config): - yield 
Scheduler(scheduler_config=scheduler_config, cache_config=cache_config) + def seq_meta(self, block_size): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): block_manager = scheduler.block_manager