From 07c667c0b49e5730524ff12e4ecd1ac0fa4e4c3d Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 26 Aug 2025 14:30:45 +0800 Subject: [PATCH 01/29] refactor SchedulerSequence --- lmdeploy/pytorch/engine/engine.py | 17 ++--- lmdeploy/pytorch/engine/logits_process.py | 2 +- lmdeploy/pytorch/engine/model_agent.py | 2 - lmdeploy/pytorch/messages.py | 71 +++++++++---------- .../block_manager/default_block_manager.py | 2 +- .../block_manager/window_block_manager.py | 4 +- lmdeploy/pytorch/paging/scheduler.py | 15 ++-- tests/pytorch/paging/test_block_manager.py | 41 +++++++---- tests/pytorch/paging/test_block_trie.py | 22 ++++-- 9 files changed, 95 insertions(+), 81 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 9c0de8fbe6..16ca8aac6a 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -586,7 +586,7 @@ def __update_max_new_tokens(msg): max_session_len = self.max_session_len sampling_param = msg.sampling_param max_new_tokens = sampling_param.max_new_tokens - num_all_tokens = msg.num_all_tokens() + num_all_tokens = msg.num_all_ids if max_new_tokens + num_all_tokens > max_session_len: logger.warning( f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. ' @@ -603,14 +603,12 @@ def __update_max_new_tokens(msg): sess = scheduler.sessions[session_id] # TODO: support 1 session n sequence sampling_param = req.data['sampling_param'] - return_logits = sampling_param.out_logits if len(sess.sequences) == 0: migration_request = req.data.get('migration_request') assert len(req.data['token_ids']) > 0, ('Empty input is not allowed.') sess.add_sequence(req.data['token_ids'], sampling_param=sampling_param, adapter_name=req.data['adapter_name'], - return_logits=return_logits, multimodals=req.data.get('input_multimodals'), input_embeddings=req.data.get('input_embeddings', ), migration_request=migration_request, @@ -630,9 +628,7 @@ def __update_max_new_tokens(msg): embeddings=req.data.get('input_embeddings'), append_tokens=True, ) - msg.num_new_tokens = 0 msg.sampling_param = sampling_param - msg.return_logits = return_logits msg.status = MessageStatus.WAITING __update_max_new_tokens(msg) @@ -670,10 +666,11 @@ def __get_vlm_embeddings(): ] input_embedding_indexing = torch.zeros((batch_size, max_q_seq_length), dtype=torch.bool) for msg_id, msg in enumerate(messages): + num_history_ids = msg.num_history_ids for emb in msg.input_embeddings: # make slice index relative to embeddings - emb_start = emb.start - msg.history_len - emb_end = emb.end - msg.history_len + emb_start = emb.start - num_history_ids + emb_end = emb.end - num_history_ids input_embedding_indexing[msg_id][emb_start:emb_end] = True return (input_embeddings, input_embedding_indexing, input_embedding_ranges) @@ -727,7 +724,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): """ batch_size = len(messages) # history lengths - history_lengths = torch.tensor([msg.history_len for msg in messages]) + history_lengths = torch.tensor([msg.num_history_ids for msg in messages]) # input ids token_ids = [msg.token_ids for msg in messages] @@ -740,8 +737,8 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): seq_length = torch.tensor(seq_length, dtype=torch.long) max_q_seqlen = seq_length.max().item() else: - seq_length = torch.ones(batch_size, dtype=torch.long) max_q_seqlen = 1 + seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long) kv_seqlens = seq_length + history_lengths max_kv_seqlen 
= kv_seqlens.max().item() sum_kv_seqlen = kv_seqlens.sum().item() @@ -804,7 +801,6 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped # fill token msg.update_token_ids(update_token, model_meta=model_meta) - msg.num_new_tokens += 1 if stop: msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED @@ -820,7 +816,6 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, # fill token msg.update_token_ids(update_token, model_meta=model_meta) - msg.num_new_tokens += 1 if stop: update_token = _EMPTY_TOKEN msg.update_token_ids(update_token, model_meta=model_meta) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 214f83256e..7413e158f4 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -154,7 +154,7 @@ def __gather_params(): top_k[idx] = param.top_k top_p[idx] = param.top_p min_p[idx] = param.min_p - random_offsets[idx] = seq.random_offsets + random_offsets[idx] = seq.num_all_ids response_formats[idx] = param.response_format if param.random_seed is not None: random_seeds[idx] = param.random_seed & 0xffffffff diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index f6f9b6e4ca..e8c254305f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -609,8 +609,6 @@ def __update_inputs(next_token_ids, model_metas): all_ids = torch.cat([all_ids, next_token_ids[:, None].to(all_ids.device)], 1) if guided_input_ids is not None: guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None].to(guided_input_ids.device)], 1) - if sampling_inputs.random_offsets is not None: - sampling_inputs.random_offsets += 1 @asynccontextmanager async def __prepare_dp(): diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index c21db48eab..027fd6839a 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -153,29 +153,33 @@ class MessageStatus(enum.Enum): MIGRATION_DONE = enum.auto() -_SEQ_COUNT = 0 - - -def _new_msg_id(): - """Get a new message id.""" - global _SEQ_COUNT - seq_id = _SEQ_COUNT - _SEQ_COUNT += 1 - return seq_id +SeqMap = Dict[int, 'SchedulerSequence'] -SeqMap = Dict[int, 'SchedulerSequence'] +@dataclass +class SequenceMeta: + """Meta data shared by all sequence.""" + block_size: int + model_type: str = 'llm' class SequenceManager: """Sequence manager.""" - def __init__(self) -> None: + def __init__(self, seq_meta: SequenceMeta) -> None: self._seq_map: SeqMap = dict() self._status_seq_map: Dict[MessageStatus, SeqMap] = dict() for status in MessageStatus: self._status_seq_map[status] = dict() + self.seq_meta = seq_meta + self._seq_count = 0 + + def _new_seq_id(self): + seq_id = self._seq_count + self._seq_count += 1 + return seq_id + def get_all_sequences(self): """Get all sequences.""" return self._seq_map.values() @@ -221,9 +225,9 @@ def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageSt class SchedulerSession: """Scheduler session.""" - def __init__(self, session_id: int, block_size: int, seq_manager: SequenceManager = None) -> None: + def __init__(self, session_id: int, seq_manager: SequenceManager) -> None: self.session_id = session_id - self.block_size = block_size + self.seq_meta = seq_manager.seq_meta self.status: MessageStatus = MessageStatus.RUNNING self.sequences: SeqMap = dict() self.seq_manager = seq_manager @@ -232,7 +236,6 @@ def 
add_sequence(self, token_ids: Tensor, sampling_param: SamplingParam = None, adapter_name: str = None, - return_logits: bool = False, multimodals: MultiModalInputs = None, input_embeddings: List[InputEmbeddings] = None, migration_request: Optional[MigrationRequest] = None, @@ -248,8 +251,9 @@ def add_sequence(self, if sampling_param is None: sampling_param = SamplingParam() + seq_id = self.seq_manager._new_seq_id() seq = SchedulerSequence( - seq_id=_new_msg_id(), + seq_id=seq_id, session=self, history_cache=HistoryTokenIds(token_ids), num_new_tokens=0, @@ -258,22 +262,19 @@ def add_sequence(self, arrive_time=time.perf_counter(), history_embeddings=HistoryEmbeddings(input_embeddings), history_multimodals=HistoryMultiModals(multimodals), - return_logits=return_logits, migration_request=migration_request, resp_cache=resp_cache, preserve_cache=preserve_cache, ) self.sequences[seq.seq_id] = seq - if self.seq_manager is not None: - self.seq_manager.add_sequence(seq) + self.seq_manager.add_sequence(seq) return seq def remove_sequence(self, seq: 'SchedulerSequence'): """Remove sequence.""" assert seq.seq_id in self.sequences self.sequences.pop(seq.seq_id) - if self.seq_manager is not None: - self.seq_manager.remove_sequence(seq) + self.seq_manager.remove_sequence(seq) def _div_up(x, n): @@ -462,8 +463,6 @@ class SchedulerSequence: adapter_name: str = None arrive_time: float = 0.0 meta: Any = None - return_logits: bool = False - random_offsets: int = 0 _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 model_meta: Dict[str, Any] = None @@ -485,16 +484,12 @@ def __post_init__(self): self._num_history_cross: int = 0 self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids) + self._seq_meta: SequenceMeta = self.session.seq_meta @property def block_size(self) -> int: """Block size.""" - return self.session.block_size - - @property - def history_len(self) -> int: - """Get history length.""" - return self._num_history_ids + return self._seq_meta.block_size @property def history_image_num(self) -> int: @@ -514,7 +509,7 @@ def session_id(self) -> int: @property def token_ids(self) -> np.ndarray: """Token ids.""" - start = self.history_len + start = self.num_history_ids end = start + self._num_token_ids return self.history_cache._token_ids[start:end] @@ -528,7 +523,7 @@ def input_embeddings(self) -> List[InputEmbeddings]: @property def history_ids(self) -> np.ndarray: """History ids.""" - return self.history_cache._token_ids[:self.history_len] + return self.history_cache._token_ids[:self.num_history_ids] @property def all_ids(self) -> np.ndarray: @@ -551,7 +546,7 @@ def num_images(self): @property def num_all_ids(self): """Num all tokens.""" - return self.history_len + self._num_token_ids + return self._num_history_ids + self._num_token_ids @property def num_cross(self): @@ -577,15 +572,15 @@ def seq_manager(self) -> SequenceManager: def status(self): return self._status + @property + def return_logits(self): + return self.sampling_param.out_logits + @status.setter def status(self, value: MessageStatus): self.seq_manager.update_sequence_status(self, value) self._status = value - def num_all_tokens(self): - """Num all tokens.""" - return self.num_all_ids - def num_all_cross_tokens(self): """Num of all cross tokens.""" return self._num_cross + self._num_history_cross @@ -640,10 +635,12 @@ def update_token_ids(self, token_ids = token_ids[None] if append_tokens: self._num_token_ids += len(token_ids) + self.num_new_tokens = 0 else: 
- self._num_token_ids = len(token_ids) + num_token_ids = len(token_ids) + self._num_token_ids = num_token_ids + self.num_new_tokens += num_token_ids self.history_cache.append(token_ids) - self.random_offsets += 1 self.arrive_time = time.perf_counter() def set_step(self, step: int): diff --git a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py index cc41187900..3348f2f8a5 100644 --- a/lmdeploy/pytorch/paging/block_manager/default_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/default_block_manager.py @@ -25,7 +25,7 @@ class DefaultBlockManager(BaseBlockManager): @classmethod def num_required_blocks(cls, obj: SchedulerSequence, prealloc_size: int = 0): """Get num required blocks.""" - num_tokens = obj.num_all_tokens() + prealloc_size + num_tokens = obj.num_all_ids + prealloc_size # cross tokens num_cross = obj.num_all_cross_tokens() diff --git a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py index 3f5ca9605e..1d91c020cb 100644 --- a/lmdeploy/pytorch/paging/block_manager/window_block_manager.py +++ b/lmdeploy/pytorch/paging/block_manager/window_block_manager.py @@ -10,10 +10,10 @@ def _num_blocks_to_drop(seq: SchedulerSequence, window_size: int): """Num blocks to free.""" - if seq.history_len <= window_size: + history_len = seq.num_history_ids + if seq.num_history_ids <= window_size: return 0 block_size = seq.block_size - history_len = seq.history_len num_blocks = len(seq.logical_blocks) win_start_block_id = (history_len - window_size) // block_size win_end_block_id = (history_len - 1) // block_size diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 0be35ab9be..bf36c74321 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -9,7 +9,7 @@ from lmdeploy.utils import get_logger, logging_timer from ..config import CacheConfig, SchedulerConfig -from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager +from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta from .block_manager import build_block_manager from .block_trie import BlockTrie @@ -50,7 +50,8 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type) - self.seq_manager = SequenceManager() + seq_meta = SequenceMeta(self.cache_config.block_size) + self.seq_manager = SequenceManager(seq_meta) @property def waiting(self): @@ -121,7 +122,7 @@ def add_session(self, session_id: int): session_id (int): New session id. 
""" assert session_id not in self.sessions - session = SchedulerSession(session_id, self.cache_config.block_size, seq_manager=self.seq_manager) + session = SchedulerSession(session_id, seq_manager=self.seq_manager) self.sessions[session_id] = session return session @@ -179,7 +180,7 @@ def _reorder_migrating(): return running_migration @logging_timer('SchedulePrefilling', logger) - def _schedule_prefill(self): + def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() @@ -203,7 +204,7 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): hanging = reversed(self.hanging) waiting = reversed(waiting) evictable = list(chain(hanging, waiting)) - return eviction_helper.evict_for_seq(seq, evictable, 0) + return eviction_helper.evict_for_seq(seq, evictable, prealloc_size) def _reorder_waiting(): """Reorder waiting.""" @@ -226,7 +227,7 @@ def _reorder_waiting(): break # allocate session memory - self.block_manager.allocate(seq) + self.block_manager.allocate(seq, prealloc_size) _to_running(seq) seq.record_event(EventType.SCHEDULED) @@ -287,7 +288,7 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" if is_prefill: - output = self._schedule_prefill() + output = self._schedule_prefill(prealloc_size) else: output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index cdf45d5d6d..6f7583c9a9 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -2,7 +2,7 @@ import pytest import torch -from lmdeploy.pytorch.messages import SchedulerSession +from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta from lmdeploy.pytorch.paging.block_manager import DefaultBlockManager, WindowBlockManager from lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator @@ -89,8 +89,14 @@ def num_gpu_blocks(self): def block_mgr(self, num_cpu_blocks, num_gpu_blocks): yield DefaultBlockManager(num_cpu_blocks, num_gpu_blocks) - def test_alloc(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + @pytest.fixture + def seq_manager(self, block_size): + seq_meta = SequenceMeta(block_size) + yield SequenceManager(seq_meta) + + def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # test alloc token_ids = torch.tensor([1]) @@ -113,9 +119,10 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_num_required_blocks(self, block_mgr, block_size, num_gpu_blocks): + def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks): from lmdeploy.pytorch.messages import InputEmbeddings - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = torch.tensor([1]) msg = sess.add_sequence(token_ids) @@ -133,8 +140,9 @@ def test_num_required_blocks(self, block_mgr, block_size, num_gpu_blocks): num_required = block_mgr.num_required_blocks(msg) assert num_required == 3 - def test_append_slot(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, 
block_size) + def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # test append token_ids = torch.tensor([1]) @@ -158,8 +166,9 @@ def test_append_slot(self, block_mgr, block_size, num_gpu_blocks): assert len(block_table) == 2 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 - def test_swap(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + def test_swap(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = torch.tensor([1] * (block_size + 1)) msg = sess.add_sequence(token_ids) @@ -215,12 +224,18 @@ def num_cpu_blocks(self): def num_gpu_blocks(self): yield 4 + @pytest.fixture + def seq_manager(self, block_size): + seq_meta = SequenceMeta(block_size) + yield SequenceManager(seq_meta) + @pytest.fixture def block_mgr(self, num_cpu_blocks, num_gpu_blocks, window_size): yield WindowBlockManager(num_cpu_blocks, num_gpu_blocks, window_size) - def test_alloc(self, block_mgr, block_size, num_gpu_blocks): - sess = SchedulerSession(0, block_size) + def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # test alloc token_ids = torch.tensor([1]) @@ -243,8 +258,8 @@ def test_alloc(self, block_mgr, block_size, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_win_alloc(self, block_mgr, block_size, num_gpu_blocks, window_size): - sess = SchedulerSession(0, block_size) + def test_win_alloc(self, block_mgr, seq_manager, num_gpu_blocks, window_size): + sess = SchedulerSession(0, seq_manager) # 2 win block token_ids = torch.tensor([1] * window_size) diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 06829f4c75..a9a3f571e9 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -2,7 +2,7 @@ import pytest from lmdeploy.pytorch.config import CacheConfig -from lmdeploy.pytorch.messages import SchedulerSession +from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta from lmdeploy.pytorch.paging.block_manager import build_block_manager from lmdeploy.pytorch.paging.block_trie import BlockTrie @@ -37,9 +37,15 @@ def block_mgr(self, cache_config): def block_trie(self, cache_config, block_mgr): yield BlockTrie(cache_config, block_mgr) - def test_allocate(self, block_trie, block_mgr, block_size): + @pytest.fixture + def seq_manager(self, block_size): + seq_meta = SequenceMeta(block_size) + yield SequenceManager(seq_meta) + + def test_allocate(self, block_trie, block_mgr, seq_manager): allocator = block_trie.allocator - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = ([1] * block_size + [2] * block_size) token_ids += [3] * (block_size // 2) seq = sess.add_sequence(token_ids) @@ -75,9 +81,10 @@ def test_allocate(self, block_trie, block_mgr, block_size): assert node in block_trie.leaves assert len(block_trie.leaves) == 1 - def test_match(self, block_trie, block_mgr, block_size): + def test_match(self, block_trie, block_mgr, seq_manager): allocator = block_trie.allocator - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size # initialize cache token_ids = ([1] * block_size + 
[2] * block_size) @@ -112,9 +119,10 @@ def test_match(self, block_trie, block_mgr, block_size): ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [4, 3]) - def test_evict(self, block_trie, block_size, num_gpu_blocks): + def test_evict(self, block_trie, seq_manager, num_gpu_blocks): block_mgr = block_trie.block_manager - sess = SchedulerSession(0, block_size) + sess = SchedulerSession(0, seq_manager) + block_size = sess.seq_meta.block_size token_ids = ([1] * block_size * (num_gpu_blocks - 1)) token_ids += [2] * (block_size // 2) seq = sess.add_sequence(token_ids) From af01586d0c0ca35195ea2a1231f6e467bf51475f Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 26 Aug 2025 21:27:46 +0800 Subject: [PATCH 02/29] block sparse attn --- .../pytorch/kernels/cuda/flashattention.py | 14 ++- .../pytorch/kernels/cuda/pagedattention.py | 52 ++++---- tests/pytorch/kernel/test_flash_attention.py | 74 ++++++++++- tests/pytorch/kernel/test_paged_attention.py | 116 ++++++++++++++---- 4 files changed, 204 insertions(+), 52 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py index b41b5394ff..428444478e 100644 --- a/lmdeploy/pytorch/kernels/cuda/flashattention.py +++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py @@ -56,7 +56,8 @@ def _load_kv(ptrs, boundary_check: tl.constexpr): def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start, loop_end, sm_scale, history_mask, kv_min_loc, causal_mask: tl.constexpr, window_size: tl.constexpr, logit_softcapping: tl.constexpr, k_bound: tl.constexpr, v_bound: tl.constexpr, - shared_kv: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK1: tl.constexpr): + shared_kv: tl.constexpr, block_sparse_size: tl.constexpr, BLOCK_N: tl.constexpr, + BLOCK_DK1: tl.constexpr): k_ptrs = tl.advance(k_ptrs, (0, loop_start)) v_ptrs = tl.advance(v_ptrs, (loop_start, 0)) if BLOCK_DK1: @@ -77,7 +78,11 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs, loop_start qk *= sm_scale qk = softcapping(qk, logit_softcapping) qk = qk * tl_log2(math.e) - qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) + if block_sparse_size > 1: + offs_mask = (start_n + offs_n) // block_sparse_size * block_sparse_size + qk_mask = (history_mask[:, None]) >= offs_mask[None, :] + else: + qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :]) if window_size > 0: qk_mask = qk_mask and ((start_n + offs_n[None, :]) >= kv_min_loc[:, None]) qk = tl.where( @@ -180,6 +185,7 @@ def _flash_prefill_fwd_kernel( window_size: tl.constexpr, logit_softcapping: tl.constexpr, shared_kv: tl.constexpr, + block_sparse_size: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DK: tl.constexpr, @@ -295,6 +301,7 @@ def _flash_prefill_fwd_kernel( k_bound=k_bound0, v_bound=v_bound0, shared_kv=shared_kv, + block_sparse_size=block_sparse_size, BLOCK_N=BLOCK_N, BLOCK_DK1=BLOCK_DK1) @@ -322,6 +329,7 @@ def _flash_prefill_fwd_kernel( k_bound=k_bound1, v_bound=v_bound1, shared_kv=shared_kv, + block_sparse_size=block_sparse_size, BLOCK_N=BLOCK_N, BLOCK_DK1=BLOCK_DK1) # epilogue @@ -440,6 +448,7 @@ def flash_attention_fwd( logit_softcapping: float = None, sinks: Tensor = None, causal: bool = True, + block_sparse_size: int = 1, kv_layout: str = 'hsd', ): """Varlen flash Attention forward. 
@@ -534,6 +543,7 @@ def grid(args): window_size=window_size, logit_softcapping=logit_softcapping, shared_kv=shared_kv, + block_sparse_size=block_sparse_size, BLOCK_DK=BLOCK_DK, BLOCK_DK1=BLOCK_DK1, BLOCK_DV=BLOCK_DV, diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py index 3b5cec82f3..5b2a123d09 100644 --- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py +++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py @@ -59,6 +59,7 @@ def _fwd_grouped_split_kernel( stride_od: tl.constexpr, stride_boffb, kv_group_num: tl.constexpr, + seq_len: tl.constexpr, window_size: tl.constexpr, head_size: tl.constexpr, head_size_v: tl.constexpr, @@ -74,18 +75,20 @@ def _fwd_grouped_split_kernel( ): """First step kernel of split k attention.""" cur_batch = tl.program_id(2) - cur_kv_head = tl.program_id(0) + tile_id = tl.program_id(0) split_k_id = tl.program_id(1) - if BLOCK_H < kv_group_num: - HEAD_PER_CTA: tl.constexpr = BLOCK_H - else: - HEAD_PER_CTA: tl.constexpr = kv_group_num - cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H) - mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA + HEADS_PER_REQ: tl.constexpr = kv_group_num * seq_len + TILES_PER_GROUP: tl.constexpr = tl.cdiv(HEADS_PER_REQ, BLOCK_H) + subtile_id = tile_id % TILES_PER_GROUP + cur_kv_head = tile_id // TILES_PER_GROUP + offs_h = subtile_id * BLOCK_H + tl.arange(0, BLOCK_H) + cur_head = cur_kv_head * kv_group_num + offs_h % kv_group_num + cur_token = cur_batch * seq_len + offs_h // kv_group_num + + mask_h = cur_head < cur_kv_head * kv_group_num + kv_group_num + mask_h = mask_h & (cur_token < cur_batch * seq_len + seq_len) mask_h = mask_h & (cur_head < num_heads_q) - if BLOCK_H < kv_group_num: - cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num q_seqlen = 1 kv_seqlen = tl.load(KV_seqlens + cur_batch) @@ -104,7 +107,7 @@ def _fwd_grouped_split_kernel( off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + offs_n[None, :] * stride_kbs) off_v = (cur_kv_head * stride_vh + offs_dv[None, :] * stride_vd + offs_n[:, None] * stride_vbs) - off_q = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) + off_q = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] * stride_qd) q = tl.load(Q + off_q, mask=mask_h[:, None] & mask_d[None, :], other=0) k_ptrs = K + off_k @@ -114,7 +117,7 @@ def _fwd_grouped_split_kernel( offs_d1 = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL1) mask_d1 = offs_d1 < head_size offs_d1 = offs_d1 % head_size - off_q1 = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) + off_q1 = (cur_token[:, None] * stride_qbs + cur_head[:, None] * stride_qh + offs_d1[None, :] * stride_qd) q1 = tl.load(Q + off_q1, mask=mask_h[:, None] & mask_d1[None, :], other=0) off_k1 = (cur_kv_head * stride_kh + offs_d1[:, None] * stride_kd + offs_n[None, :] * stride_kbs) k1_ptrs = K + off_k1 @@ -196,11 +199,11 @@ def _fwd_grouped_split_kernel( # initialize pointers to output if loop_end > loop_start: - off_acc = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh + + off_acc = (cur_token[:, None] * stride_obs + split_k_id * stride_ok + cur_head[:, None] * stride_oh + offs_dv[None, :] * stride_od) tl.store(Acc_out + off_acc, acc, mask=mask_h[:, None] & mask_dv[None, :]) - off_meta = (cur_batch * stride_obs + split_k_id * stride_ok + cur_head * stride_oh + head_size_v) + off_meta = (cur_token * stride_obs + split_k_id * stride_ok + cur_head * 
stride_oh + head_size_v) tl.store(Acc_out + off_meta, m_i, mask=mask_h) tl.store(Acc_out + off_meta + 1, l_i, mask=mask_h) @@ -588,7 +591,9 @@ def _get_block_d(Lk): if sm_scale is None: sm_scale = 1.0 / (Lq**0.5) batch, head = kv_seqlens.shape[0], q.shape[-2] - kv_group_num = q.shape[-2] // k.shape[h_dim] + num_tokens = q.shape[-3] + num_kv_heads = k.shape[h_dim] + kv_group_num = head // num_kv_heads if sinks is not None: assert sinks.is_contiguous() @@ -601,20 +606,22 @@ def _get_block_d(Lk): 'might leads to bad performance. ' 'Please reduce `block_size`.') - is_decoding = q.shape[-3] == kv_seqlens.size(0) - assert is_decoding, 'we only support decoding paged attention.' + valid = num_tokens % batch == 0 + assert valid, 'we only support decoding paged attention.' + seq_len = num_tokens // batch BLOCK_DMODEL, BLOCK_DMODEL1, BLOCK_DV = _get_block_d(Lq) - p2_kv_group_num = triton.next_power_of_2(kv_group_num) - BLOCK_H = max(16, min(BLOCK, p2_kv_group_num)) - grid_1 = triton.cdiv(head, min(BLOCK_H, kv_group_num)) + HEADS_PER_REQ = kv_group_num * seq_len + BLOCK_H = max(16, min(BLOCK, triton.next_power_of_2(HEADS_PER_REQ))) + TILES_PER_GROUP = triton.cdiv(HEADS_PER_REQ, BLOCK_H) + grid_1 = TILES_PER_GROUP * num_kv_heads SPLIT_K = _get_split_k(q.device.index, grid_1, batch) if quant_policy != 4: - acc = q.new_empty(batch, head, SPLIT_K, Lv + 2, dtype=torch.float32) + acc = q.new_empty(num_tokens, head, SPLIT_K, Lv + 2, dtype=torch.float32) else: - acc = q.new_empty(batch, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32) + acc = q.new_empty(num_tokens, head, SPLIT_K, o.shape[-1] + 2, dtype=torch.float32) grid = ( grid_1, @@ -704,6 +711,7 @@ def _get_block_d(Lk): stride_od=acc.stride(-1), stride_boffb=block_offsets.stride(0), kv_group_num=kv_group_num, + seq_len=seq_len, window_size=window_size, head_size=Lk, head_size_v=Lv, @@ -720,7 +728,7 @@ def _get_block_d(Lk): num_stages=num_stages) num_warps = 4 - grid = (batch, head) + grid = (num_tokens, head) if quant_policy == 4: Lv *= 2 BLOCK_DV *= 2 diff --git a/tests/pytorch/kernel/test_flash_attention.py b/tests/pytorch/kernel/test_flash_attention.py index 2b6d17aa2f..5e1ad7cf1a 100644 --- a/tests/pytorch/kernel/test_flash_attention.py +++ b/tests/pytorch/kernel/test_flash_attention.py @@ -11,16 +11,18 @@ def _conti_input(data, q_seqlens): def _make_bias(q_seqlens, history_lens, neg_val, causal): + batch_size = q_seqlens.shape[0] kv_seqlens = q_seqlens + history_lens max_seq_len = q_seqlens.max().item() max_kv_len = kv_seqlens.max().item() if causal: - seq_ranges = [torch.arange(max_seq_len) for _ in q_seqlens] - for r, l in zip(seq_ranges, q_seqlens): - r[l:] = -max_kv_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_kv_len) for _ in kv_seqlens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges.repeat(batch_size, 1) + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]) return mask.float() * neg_val else: @@ -31,6 +33,27 @@ def _make_bias(q_seqlens, history_lens, neg_val, causal): return (~mask).float() * neg_val +def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float, + block_sparse_size: int): + """Make block sparse bias.""" + batch_size = q_seqlens.shape[0] + kv_seqlens = q_seqlens 
+ history_lens + max_seq_len = q_seqlens.max().item() + max_kv_len = kv_seqlens.max().item() + + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges // block_sparse_size * block_sparse_size + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges // block_sparse_size * block_sparse_size + kv_ranges = kv_ranges.repeat(batch_size, 1) + + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]) + return mask.float() * neg_val + + def _naive_attention(batched_q, batched_kv, bias, sinks=None): batched_k, batched_v = batched_kv @@ -283,3 +306,42 @@ def test_sinks(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv sinks=sinks, causal=causal) torch.testing.assert_close(out, conti_sink_gt, atol=1e-3, rtol=1e-5) + + # block sparse attention + @pytest.fixture + def block_sparse_size(self): + yield 4 + + @pytest.fixture + def block_sparse_mask(self, q_seqlens, history_lens, block_sparse_size): + neg_val = -1e30 + yield _make_block_sparse_bias(q_seqlens, history_lens, neg_val, block_sparse_size) + + @pytest.fixture + def block_sparse_gt(self, batched_q, batched_kv, block_sparse_mask): + yield _naive_attention(batched_q, batched_kv, block_sparse_mask) + + @pytest.mark.parametrize('head_dim_k', [32], indirect=True) + @pytest.mark.parametrize('head_dim_v', [32], indirect=True) + @pytest.mark.parametrize('num_heads_q', [8], indirect=True) + @pytest.mark.parametrize('num_heads_k', [2], indirect=True) + @pytest.mark.parametrize(['q_seqlens', 'history_lens'], [([16, 32], [64, 8])], indirect=True) + def test_block_sparse_attention(self, conti_q, conti_kv, q_start_loc, q_seqlens, kv_start_loc, kv_seqlens, + head_dim_v, block_sparse_size, block_sparse_gt): + from lmdeploy.pytorch.kernels.cuda.flashattention import flash_attention_fwd + max_seq_len = q_seqlens.max().item() + + conti_k, conti_v = conti_kv + out = conti_q.new_empty(*conti_q.shape[:-1], head_dim_v) + flash_attention_fwd(conti_q, + conti_k, + conti_v, + out, + q_start_loc=q_start_loc, + q_seqlens=q_seqlens, + kv_start_loc=kv_start_loc, + kv_seqlens=kv_seqlens, + max_seqlen=max_seq_len, + block_sparse_size=block_sparse_size) + gt = _conti_input(block_sparse_gt, q_seqlens) + torch.testing.assert_close(out, gt, atol=1e-3, rtol=1e-5) diff --git a/tests/pytorch/kernel/test_paged_attention.py b/tests/pytorch/kernel/test_paged_attention.py index a6289fc626..f4268cacdc 100644 --- a/tests/pytorch/kernel/test_paged_attention.py +++ b/tests/pytorch/kernel/test_paged_attention.py @@ -10,20 +10,42 @@ def _conti_input(data, seq_lens): return data -def _make_bias(seq_lens, history_lens, neg_val): - full_seq_lens = seq_lens + history_lens - max_seq_len = seq_lens.max().item() - max_full_len = full_seq_lens.max().item() - seq_ranges = [torch.arange(max_seq_len) for _ in seq_lens] - for r, l in zip(seq_ranges, seq_lens): - r[l:] = -max_full_len - seq_ranges = torch.stack(seq_ranges, dim=0).cuda() - kv_ranges = [torch.arange(max_full_len) for _ in full_seq_lens] - kv_ranges = torch.stack(kv_ranges, 0).cuda() +def _make_bias(q_seqlens, history_lens, neg_val): + batch_size = q_seqlens.shape[0] + full_seq_lens = q_seqlens + history_lens + max_seq_len = q_seqlens.max().item() + max_kv_len = full_seq_lens.max().item() + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, 
None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges.repeat(batch_size, 1) mask = kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None] return mask.float() * neg_val +def _make_block_sparse_bias(q_seqlens: torch.Tensor, history_lens: torch.Tensor, neg_val: float, + block_sparse_size: int): + """Make block sparse bias.""" + batch_size = q_seqlens.shape[0] + kv_seqlens = q_seqlens + history_lens + max_seq_len = q_seqlens.max().item() + max_kv_len = kv_seqlens.max().item() + + seq_ranges = torch.arange(max_seq_len).cuda() + seq_ranges = seq_ranges // block_sparse_size * block_sparse_size + seq_ranges = seq_ranges.repeat(batch_size, 1) + seq_ranges = torch.where(seq_ranges < q_seqlens[:, None], seq_ranges, -max_kv_len) + + kv_ranges = torch.arange(max_kv_len).cuda() + kv_ranges = kv_ranges // block_sparse_size * block_sparse_size + kv_ranges = kv_ranges.repeat(batch_size, 1) + + mask = (kv_ranges[:, None, :] - seq_ranges[:, :, None] > history_lens[:, None, None]) + return mask.float() * neg_val + + def _make_blocked_cache(batched_k, batched_v, seq_lens, @@ -119,7 +141,7 @@ def _make_cu_seqlens(seqlens): return output -class TestPagedAttention: +class TestPagedAttentionBase: @pytest.fixture def dtype(self): @@ -154,18 +176,22 @@ def history_lens(self, request): yield torch.tensor(request.param, device='cuda') @pytest.fixture - def seq_lens(self, history_lens): - yield torch.ones_like(history_lens) + def seq_len(self): + yield 1 + + @pytest.fixture + def seq_lens(self, seq_len, history_lens): + yield torch.ones_like(history_lens) * seq_len @pytest.fixture - def kv_seqlens(self, history_lens): - yield 1 + history_lens + def kv_seqlens(self, seq_lens, history_lens): + yield seq_lens + history_lens @pytest.fixture - def batched_q(self, kv_seqlens, num_heads_q, feat_dim, dtype): + def batched_q(self, seq_len, kv_seqlens, num_heads_q, feat_dim, dtype): torch.manual_seed(123) batch_size = len(kv_seqlens) - inputs = torch.rand(batch_size, 1, num_heads_q, feat_dim, dtype=dtype, device='cuda') + inputs = torch.rand(batch_size, seq_len, num_heads_q, feat_dim, dtype=dtype, device='cuda') yield inputs @pytest.fixture @@ -178,8 +204,7 @@ def batched_kv(self, kv_seqlens, num_heads_k, feat_dim, feat_dim_v, dtype): yield k, v @pytest.fixture - def conti_q(self, kv_seqlens, batched_q): - seq_lens = torch.ones_like(kv_seqlens) + def conti_q(self, seq_lens, batched_q): yield _conti_input(batched_q, seq_lens) @pytest.fixture @@ -225,7 +250,10 @@ def gt(self, batched_q, batched_kv, mask): def conti_gt(self, gt, seq_lens): yield _conti_input(gt, seq_lens) - @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True) + +class TestPagedAttention(TestPagedAttentionBase): + + @pytest.mark.parametrize('feat_dim', [32, 32], indirect=True) @pytest.mark.parametrize('feat_dim_v', [32], indirect=True) @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True) @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True) @@ -288,7 +316,7 @@ def test_window_attention(self, conti_q, blocked_kv, block_offsets, history_lens torch.testing.assert_close(out, window_gt, atol=1e-3, rtol=1e-5) -class TestPagedAttentionSink(TestPagedAttention): +class TestPagedAttentionSink(TestPagedAttentionBase): @pytest.fixture def sinks(self, num_heads_q, dtype): @@ -308,7 +336,8 @@ def conti_sink_gt(self, sink_gt, seq_lens): @pytest.mark.parametrize('history_lens', [(50, 40, 30, 20)], indirect=True) 
@pytest.mark.parametrize('block_size', [16], indirect=True) @pytest.mark.parametrize('layout', ['bshd'], indirect=True) - def test_sink(self, conti_q, blocked_kv, block_offsets, history_lens, feat_dim_v, layout, sinks, conti_sink_gt): + def test_paged_attention(self, conti_q, blocked_kv, block_offsets, history_lens, feat_dim_v, layout, sinks, + conti_sink_gt): from lmdeploy.pytorch.kernels.cuda import paged_attention_fwd kv_seq_lens = 1 + history_lens @@ -450,3 +479,46 @@ class TestPagedAttentionInt4(TestPagedAttentionInt8): @pytest.fixture def nbits(self): yield 4 + + +class TestPagedAttentionBlockDecoding(TestPagedAttentionBase): + + @pytest.fixture + def seq_len(self): + yield 4 + + @pytest.fixture + def mask(self, seq_lens, history_lens, seq_len): + neg_val = -1e30 + yield _make_block_sparse_bias(seq_lens, history_lens, neg_val, seq_len) + + @pytest.fixture + def gt(self, batched_q, batched_kv, mask): + yield _naive_attention(batched_q, batched_kv, mask) + + @pytest.fixture + def conti_gt(self, gt, seq_lens): + yield _conti_input(gt, seq_lens) + + @pytest.mark.parametrize('feat_dim', [48, 32], indirect=True) + @pytest.mark.parametrize('feat_dim_v', [32], indirect=True) + @pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2), (2, 2)], indirect=True) + @pytest.mark.parametrize('history_lens', [(52, 40, 32, 20)], indirect=True) + @pytest.mark.parametrize('block_size', [16], indirect=True) + @pytest.mark.parametrize('layout', ['bshd'], indirect=True) + def test_paged_attention(self, conti_q, blocked_kv, block_offsets, seq_lens, history_lens, feat_dim_v, layout, + conti_gt): + from lmdeploy.pytorch.kernels.cuda import paged_attention_fwd + kv_seq_lens = seq_lens + history_lens + + blocked_k, blocked_v = blocked_kv + out = conti_q.new_empty(*conti_q.shape[:-1], feat_dim_v) + + paged_attention_fwd(conti_q, + blocked_k, + blocked_v, + out, + block_offsets=block_offsets, + kv_seqlens=kv_seq_lens, + kv_layout=layout) + torch.testing.assert_close(out, conti_gt, atol=1e-3, rtol=1e-5) From e328c5dffd858d97ebdf94fbeef5a74e1b1ca96c Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 1 Sep 2025 15:02:27 +0800 Subject: [PATCH 03/29] support SDAR --- lmdeploy/messages.py | 2 + lmdeploy/pytorch/backends/attention.py | 1 + lmdeploy/pytorch/backends/cuda/attention.py | 10 +- .../pytorch/backends/cuda/graph_runner.py | 27 +- lmdeploy/pytorch/config.py | 4 + lmdeploy/pytorch/configurations/sdar.py | 18 + lmdeploy/pytorch/engine/engine.py | 192 ++++++--- lmdeploy/pytorch/engine/logits_process.py | 112 ++++- lmdeploy/pytorch/engine/model_agent.py | 309 +++++++++---- lmdeploy/pytorch/engine/unmasking.py | 73 ++++ lmdeploy/pytorch/messages.py | 346 ++++++++++++--- lmdeploy/pytorch/model_inputs.py | 67 ++- lmdeploy/pytorch/models/module_map.py | 5 + lmdeploy/pytorch/models/patch.py | 13 +- lmdeploy/pytorch/models/sdar.py | 405 ++++++++++++++++++ lmdeploy/pytorch/models/utils/cudagraph.py | 6 +- lmdeploy/pytorch/nn/attention.py | 2 + lmdeploy/pytorch/paging/block_trie.py | 8 +- lmdeploy/pytorch/paging/scheduler.py | 7 +- 19 files changed, 1352 insertions(+), 255 deletions(-) create mode 100644 lmdeploy/pytorch/configurations/sdar.py create mode 100644 lmdeploy/pytorch/engine/unmasking.py create mode 100644 lmdeploy/pytorch/models/sdar.py diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index d3f4e5b116..5ec4024c36 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -333,6 +333,7 @@ class PytorchEngineConfig: It can be used to override the default config of the model, 
disable_vision_encoder (bool): Whether to disable loading vision encoder. Default to False. + block_sparse_size (int): Block size of block diffusion model. logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs'] """ dtype: str = 'auto' @@ -367,6 +368,7 @@ class PytorchEngineConfig: enable_metrics: bool = False hf_overrides: Optional[Dict[str, Any]] = None disable_vision_encoder: bool = False + block_sparse_size: int = 1 logprobs_mode: str = None role: EngineRole = EngineRole.Hybrid diff --git a/lmdeploy/pytorch/backends/attention.py b/lmdeploy/pytorch/backends/attention.py index 0792c0ccca..0c842e8871 100644 --- a/lmdeploy/pytorch/backends/attention.py +++ b/lmdeploy/pytorch/backends/attention.py @@ -93,6 +93,7 @@ def build( causal: bool = True, use_flash_mla: bool = False, learnable_sink: bool = False, + block_sparse_size: int = 1, **kwargs, ) -> AttentionImpl[T]: """build.""" diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index 2a9fcc5201..b5ecc3cbef 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -62,6 +62,7 @@ def __init__( sliding_window: int = None, logit_softcapping: float = None, causal: bool = True, + block_sparse_size: int = 1, **kwargs, ): super().__init__( @@ -91,6 +92,7 @@ def __init__( world_size, rank = get_tp_world_rank() self.alibi_head_offset = self.num_heads * rank self.alibi_num_heads = self.num_heads * world_size + self.block_sparse_size = block_sparse_size def forward( self, @@ -116,7 +118,7 @@ def forward( kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy if attn_metadata.is_decoding: - max_q_seqlen = 1 + max_q_seqlen = self.block_sparse_size else: max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) fill_max_q_seqlen = max_q_seqlen @@ -213,6 +215,7 @@ def forward( logit_softcapping=self.logit_softcapping, sinks=learnable_sink, causal=self.causal, + block_sparse_size=self.block_sparse_size, ) return attn_output @@ -528,9 +531,11 @@ def build( causal: bool = True, use_flash_mla: bool = False, learnable_sink: bool = False, + block_sparse_size: int = 1, **kwargs, ) -> TritonAttentionImpl: """build.""" + enable_fa3 = use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1 if use_flash_mla is True: return FlashMLAImpl(num_heads, head_size, @@ -542,7 +547,7 @@ def build( logical_softcapping=logical_softcapping, causal=causal, **kwargs) - elif use_fa3 and not alibi and not learnable_sink: + elif enable_fa3: return FA3Impl(num_heads, head_size, scale=scale, @@ -563,4 +568,5 @@ def build( sliding_window=sliding_window, logical_softcapping=logical_softcapping, causal=causal, + block_sparse_size=block_sparse_size, **kwargs) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index ccc413be75..a866f584b5 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -12,6 +12,7 @@ from lmdeploy.utils import get_logger from ..graph_runner import GraphRunner +from .attention import TritonAttentionMetadata logger = get_logger('lmdeploy') @@ -173,18 +174,30 @@ def _get_capture_tokens(self, batch_size: int): assert False, f'Unsupported batch_size={batch_size}' def get_graph_key(self, input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: List, - attn_metadata: Any, inputs_embeds: torch.Tensor, **kwargs): + attn_metadata: TritonAttentionMetadata, 
inputs_embeds: torch.Tensor, **kwargs): """Get graph key.""" context = self.ctx_mgr.current_context() is_decoding = context.is_decoding - num_tokens = input_ids.numel() + batch_size = attn_metadata.q_seqlens.size(0) meta = self.get_meta() enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch if meta.padding_batch_size is None: - new_num_tokens = self._get_capture_tokens(num_tokens) + batch_size = self._get_capture_tokens(batch_size) else: - new_num_tokens = self._get_capture_tokens(meta.padding_batch_size) - return (new_num_tokens, is_decoding, enable_microbatch) + batch_size = self._get_capture_tokens(meta.padding_batch_size) + return (batch_size, is_decoding, enable_microbatch) + + def _get_max_tokens(self, graph_key: tuple): + max_batches = graph_key[0] + is_decoding = graph_key[1] + assert is_decoding + model_paradigm = self.model_config.model_paradigm + if model_paradigm == 'dllm': + step_mgr = get_step_ctx_manager() + build_ctx = step_mgr.build_ctx + block_sparse_size = build_ctx.block_sparse_size + return max_batches * block_sparse_size + return max_batches def __call__(self, **kwargs): """call.""" @@ -198,10 +211,10 @@ def __call__(self, **kwargs): return self.model(**kwargs) graph_key = self.get_graph_key(**kwargs) - max_tokens = graph_key[0] + max_batches = graph_key[0] is_decoding = graph_key[1] if graph_key not in self._runner_map: - max_batches = max_tokens if is_decoding else self.max_batches + max_tokens = self._get_max_tokens(graph_key) runner = CUDASingleGraphRunner(self.model, max_batches=max_batches, max_tokens=max_tokens, diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 05c716c6a9..6cb6d5cca7 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -200,6 +200,8 @@ class ModelConfig: cogvlm_style: bool = False custom_module_map: Dict[str, setattr] = None use_flash_mla: bool = False + model_paradigm: str = 'llm' + dllm_mask_token: int = 0 def get_head_size(self): """Get head size.""" @@ -294,6 +296,7 @@ class MiscConfig: hf_overrides: Dict[str, Any] = None disable_vision_encoder: bool = False logprobs_mode: str = None + block_sparse_size: int = 1 @classmethod def from_engine_config(cls, engine_config: PytorchEngineConfig): @@ -304,5 +307,6 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig): model_format=engine_config.model_format, hf_overrides=engine_config.hf_overrides, disable_vision_encoder=engine_config.disable_vision_encoder, + block_sparse_size=engine_config.block_sparse_size, logprobs_mode=engine_config.logprobs_mode) return misc_config diff --git a/lmdeploy/pytorch/configurations/sdar.py b/lmdeploy/pytorch/configurations/sdar.py new file mode 100644 index 0000000000..fdf6353760 --- /dev/null +++ b/lmdeploy/pytorch/configurations/sdar.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .default import AutoModelConfigBuilder, DefaultModelConfigBuilder + + +class SDARModelConfigBuilder(AutoModelConfigBuilder): + + @classmethod + def condition(cls, hf_config): + """config.""" + return hf_config.model_type == 'sdar' + + @classmethod + def build(cls, hf_config, model_path: str = None, **kwargs): + """build.""" + cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) + cfg.dllm_mask_token = 151669 + cfg.model_paradigm = 'dllm' + return cfg diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 16ca8aac6a..d017519356 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -20,7 +20,7 @@ from ..adapter.adapter import AdapterManager from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig -from ..messages import MessageStatus, SchedulerSequence +from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler from .base import EngineBase @@ -139,6 +139,16 @@ def _build_misc_config(engine_config: PytorchEngineConfig): return misc_config +def _build_seq_meta(cache_config: CacheConfig, model_config: ModelConfig, engine_config: PytorchEngineConfig): + from lmdeploy.pytorch.messages import SequenceMeta + + seq_meta = SequenceMeta(cache_config.block_size, + model_paradigm=model_config.model_paradigm, + block_sparse_size=engine_config.block_sparse_size, + dllm_mask_token=model_config.dllm_mask_token) + return seq_meta + + class CounterEvent: def __init__(self): @@ -376,7 +386,8 @@ def __init__(self, self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.scheduler = Scheduler(scheduler_config, cache_config) + self.seq_meta = _build_seq_meta(cache_config, self.model_config, engine_config) + self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta) # engine args self.model_path = model_path @@ -586,7 +597,7 @@ def __update_max_new_tokens(msg): max_session_len = self.max_session_len sampling_param = msg.sampling_param max_new_tokens = sampling_param.max_new_tokens - num_all_tokens = msg.num_all_ids + num_all_tokens = msg.num_valid_ids if max_new_tokens + num_all_tokens > max_session_len: logger.warning( f'session[{msg.session_id}]: num tokens is larger than max session len {max_session_len}. 
' @@ -626,7 +637,7 @@ def __update_max_new_tokens(msg): req.data['token_ids'], multimodals=req.data.get('input_multimodals'), embeddings=req.data.get('input_embeddings'), - append_tokens=True, + mode=UpdateTokenMode.INPUTS, ) msg.sampling_param = sampling_param msg.status = MessageStatus.WAITING @@ -737,7 +748,7 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): seq_length = torch.tensor(seq_length, dtype=torch.long) max_q_seqlen = seq_length.max().item() else: - max_q_seqlen = 1 + max_q_seqlen = len(token_ids[0]) seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long) kv_seqlens = seq_length + history_lengths max_kv_seqlen = kv_seqlens.max().item() @@ -788,22 +799,60 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): return model_inputs - def update_running(self, running: SeqList, next_token_ids: torch.Tensor, stopped: torch.Tensor, - model_metas: List[Dict[str, Any]]): - """Update scheduler.""" - if model_metas is None: - model_metas = [None] * len(running) + def _update_running_default(self, running: SeqList, next_token_ids: torch.Tensor, stopped: List[bool], + model_metas: List[Any], is_decoding: bool): + """Update running default.""" next_token_ids = next_token_ids.numpy() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): if msg.status != MessageStatus.LOCKED: continue update_token = token # fill token - msg.update_token_ids(update_token, model_meta=model_meta) + msg.update_token_ids(update_token, model_meta=model_meta, mode=update_mode) + if stop: + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + + def _update_running_dllm(self, running: SeqList, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, + stopped: List[bool], model_metas: List[Any], is_decoding: bool, stop_pos: torch.Tensor): + block_sparse_size = self.seq_meta.block_sparse_size + next_token_ids = next_token_ids.view(-1, block_sparse_size).numpy() + dllm_mask = dllm_mask.view(-1, block_sparse_size).numpy() + stop_pos = stop_pos.tolist() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL + for idx, token in enumerate(next_token_ids): + msg = running[idx] + stop = stopped[idx] + model_meta = model_metas[idx] + mask = dllm_mask[idx] + if msg.status != MessageStatus.LOCKED: + continue + update_token = token + update_mask = mask + + # fill token + msg.update_token_ids(update_token, dllm_mask=update_mask, model_meta=model_meta, mode=update_mode) if stop: + msg.set_stop_pos(stop_pos[idx]) msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool): + """Update scheduler.""" + next_token_ids = batched_outputs.next_token_ids + stopped = batched_outputs.stopped + stopped = stopped.tolist() + model_metas = batched_outputs.model_metas + if model_metas is None: + model_metas = [None] * len(running) + if self.model_config.model_paradigm == 'dllm': + dllm_mask = batched_outputs.dllm_mask + stop_pos = batched_outputs.stop_pos + return self._update_running_dllm(running, next_token_ids, dllm_mask, stopped, model_metas, is_decoding, + stop_pos) + else: + return self._update_running_default(running, next_token_ids, stopped, model_metas, is_decoding) + def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, 
model_metas: List[Dict[str, Any]]): """Update scheduler.""" @@ -815,36 +864,33 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, update_token = token # fill token - msg.update_token_ids(update_token, model_meta=model_meta) + msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) if stop: update_token = _EMPTY_TOKEN - msg.update_token_ids(update_token, model_meta=model_meta) + msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) msg.status = MessageStatus.STOPPED def _make_infer_outputs( self, batched_outputs: BatchedOutputs, running: SeqList, + is_decoding: bool, ): """Make infer output.""" new_token_timestamp = batched_outputs.new_token_timestamp - next_token_ids = batched_outputs.next_token_ids logits = batched_outputs.logits - stopped = batched_outputs.stopped - model_metas = batched_outputs.model_metas logprobs = batched_outputs.logprobs seq_length = [seq.num_token_ids for seq in running] is_run = [seq.status == MessageStatus.LOCKED for seq in running] - stopped = stopped.tolist() - self.update_running(running, next_token_ids, stopped, model_metas) + self.update_running(running, batched_outputs, is_decoding) # generate output outputs: Dict[int, InferOutput] = dict() for idx, msg in enumerate(running): if not is_run[idx]: continue - token_ids = msg.all_ids[-msg.num_new_tokens:] + token_ids = msg.generated_ids finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED if not finish and len(token_ids) == 0: continue @@ -878,58 +924,67 @@ def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): """Make forward inputs.""" prefill_interval = self.scheduler_config.prefill_interval - def __gather_all_ids(seqs: SeqList, sampling_inputs: SamplingInputs): - """Gather history.""" - if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): - return None - batch = len(seqs) - max_len = max(seq.num_all_ids for seq in seqs) - pad_id = self.model_config.bos_token_id - pad_id = 0 if pad_id is None else pad_id - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_all_ids - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.all_ids) - output[idx, -h_len:] = h_ids - return output - - def __gather_guided_input_ids(seqs: SeqList, sampling_inputs: SamplingInputs): - """Gather input ids for guided decode.""" - if not any(sampling_inputs.response_formats or ()): - return None - batch = len(seqs) - max_len = max(seq.num_new_tokens for seq in seqs) - pad_id = self.model_config.bos_token_id - pad_id = 0 if pad_id is None else pad_id - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_new_tokens - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.all_ids[-seq.num_new_tokens:]) - output[idx, -h_len:] = h_ids - return output + def __get_prealloc(prefill_interval: int): + model_paradigm = self.model_config.model_paradigm + if model_paradigm == 'dllm': + block_size = self.cache_config.block_size + block_sparse_size = self.seq_meta.block_sparse_size + num_blocks = min(prefill_interval // 2, block_size // block_sparse_size) + return num_blocks * block_sparse_size + else: + return prefill_interval + + def __get_num_loops(is_prefill: bool, prefill_interval: int): + model_paradigm = self.model_config.model_paradigm + if is_prefill: + return 1 + if model_paradigm == 'dllm': + block_size = 
self.cache_config.block_size + block_sparse_size = self.seq_meta.block_sparse_size + max_num_loops = block_size // block_sparse_size * 2 + num_loops = min(prefill_interval, max_num_loops) + return num_loops + else: + return prefill_interval def __get_num_appendable_ids(seqs: SeqList): """Get num appendable ids.""" - ret = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] - return torch.tensor(ret) - - def __get_num_ignore_eos(seqs: SeqList): - """Get num ignore eos.""" - ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] - return torch.tensor(ret) + num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] + num_appendable = torch.tensor(num_appendable) + if self.model_config.model_paradigm == 'dllm': + block_sparse_size = self.seq_meta.block_sparse_size + remain = [seq.num_valid_ids % block_sparse_size for seq in seqs] + num_appendable += torch.tensor(remain) + return num_appendable + + def __get_dllm_mask(seqs: SeqList): + """Get dllm mask.""" + if self.model_config.model_paradigm != 'dllm': + return None + dllm_masks = [seq.dllm_mask for seq in seqs] + dllm_masks = torch.as_tensor(np.concatenate(dllm_masks)) + return dllm_masks def __need_logits(seqs: SeqList): """Need logits.""" return any(seq.return_logits for seq in seqs) + def __make_sampling_inputs(seqs: SeqList): + pad_id = self.model_config.bos_token_id + pad_id = 0 if pad_id is None else pad_id + if self.model_config.model_paradigm == 'dllm': + from .logits_process import SamplingInputsDLLM + block_sparse_size = self.seq_meta.block_sparse_size + return SamplingInputsDLLM.from_sampling_params(seqs, + pad_token_id=pad_id, + block_sparse_size=block_sparse_size) + return SamplingInputs.from_sampling_params(seqs, pad_token_id=pad_id) + scheduler = self.scheduler logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) + prealloc_size = __get_prealloc(prefill_interval) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) if enable_empty and len(scheduler_output.running) == 0: return None @@ -937,9 +992,9 @@ def __need_logits(seqs: SeqList): # schedule decoding if no valid prefill reqs. 
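The dllm branch above sizes preallocation and the decode loop in whole sparse blocks. A standalone sketch of that arithmetic, using made-up sizes (block_size=64, block_sparse_size=4, prefill_interval=16) rather than the engine's real config, mirroring __get_prealloc / __get_num_loops rather than calling them:

# Sketch of the dllm preallocation / loop-count arithmetic (illustrative values only).
def prealloc_size(prefill_interval: int, block_size: int, block_sparse_size: int) -> int:
    # preallocate whole sparse blocks, capped by how many fit in one KV block
    num_blocks = min(prefill_interval // 2, block_size // block_sparse_size)
    return num_blocks * block_sparse_size


def num_decode_loops(prefill_interval: int, block_size: int, block_sparse_size: int) -> int:
    # at most twice the number of sparse blocks that fit in one KV block
    max_num_loops = block_size // block_sparse_size * 2
    return min(prefill_interval, max_num_loops)


if __name__ == '__main__':
    print(prealloc_size(16, 64, 4))     # 32 tokens preallocated per schedule
    print(num_decode_loops(16, 64, 4))  # 16 decode loops per forward task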
if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill: prefill = False - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prefill_interval) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) - num_loops = 1 if prefill else prefill_interval + num_loops = __get_num_loops(prefill, prefill_interval) running = scheduler_output.running swap_in_map = scheduler_output.swap_in_map swap_out_map = scheduler_output.swap_out_map @@ -949,12 +1004,10 @@ def __need_logits(seqs: SeqList): # create inputs inputs = self.create_model_inputs(running, prefill) - sampling_inputs = SamplingInputs.from_sampling_params(running) - all_ids = __gather_all_ids(running, sampling_inputs) - guided_input_ids = __gather_guided_input_ids(running, sampling_inputs) + sampling_inputs = __make_sampling_inputs(running) num_appendable_ids = __get_num_appendable_ids(running) - num_ignore_eos = __get_num_ignore_eos(running) return_logits = __need_logits(running) + dllm_mask = __get_dllm_mask(running) sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num return dict( @@ -963,14 +1016,12 @@ def __need_logits(seqs: SeqList): swap_in_map=swap_in_map, swap_out_map=swap_out_map, loop_count=num_loops, - all_ids=all_ids, - guided_input_ids=guided_input_ids, sampling_inputs=sampling_inputs, num_appendable_ids=num_appendable_ids, - num_ignore_eos=num_ignore_eos, return_logits=return_logits, is_dummy=False, sync_long_context=sync_long_context, + dllm_mask=dllm_mask, ) async def _await_forward_event(self, forward_event: asyncio.Event): @@ -1138,6 +1189,7 @@ async def _async_loop_main( forward_event.set() num_loops = forward_inputs['loop_count'] + is_decoding = forward_inputs['inputs'].is_decoding running = next_running next_running = None scheduler.lock_running(running) @@ -1151,7 +1203,7 @@ async def _async_loop_main( # send output out = await self.executor.get_output_async() if out is not None: - step_outputs = self._make_infer_outputs(out, running=running) + step_outputs = self._make_infer_outputs(out, running=running, is_decoding=is_decoding) resp_que.put_nowait(step_outputs) # lock forward event diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 7413e158f4..89be71696c 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -109,6 +109,47 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided return scores +SeqList = List[SchedulerSequence] + + +def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): + """Gather history.""" + if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): + return None + batch = len(seqs) + max_len = max(seq.num_valid_ids for seq in seqs) + output = torch.full((batch, max_len), pad_id, dtype=torch.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_valid_ids + if h_len == 0: + continue + h_ids = torch.from_numpy(seq.valid_ids) + output[idx, -h_len:] = h_ids + return output + + +def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): + """Gather input ids for guided decode.""" + if not any(sampling_inputs.response_formats or ()): + return None + batch = len(seqs) + max_len = max(seq.num_new_tokens for seq in seqs) + output = torch.full((batch, max_len), pad_id, dtype=torch.int64) + for idx, seq in enumerate(seqs): + h_len = 
seq.num_new_tokens + if h_len == 0: + continue + h_ids = torch.from_numpy(seq.generated_ids) + output[idx, -h_len:] = h_ids + return output + + +def _get_num_ignore_eos(seqs: SeqList): + """Get num ignore eos.""" + ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] + return torch.tensor(ret) + + @dataclass class SamplingInputs: temperature: torch.Tensor = None @@ -120,16 +161,19 @@ class SamplingInputs: top_k: torch.LongTensor = None top_p: torch.Tensor = None min_p: torch.Tensor = None - random_seeds: int = None - random_offsets: int = None + random_seeds: torch.Tensor = None + random_offsets: torch.Tensor = None max_top_k: int = 1 min_top_p: float = 1.0 response_formats: Tuple[str] = () logits_processors: List[List[LogitsProcessor]] = None max_num_logprobs: Optional[int] = None + all_ids: Optional[torch.Tensor] = None + guided_input_ids: Optional[torch.Tensor] = None + num_ignore_eos: torch.Tensor = None @classmethod - def from_sampling_params(cls, seqs: List[SchedulerSequence]): + def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0): """From samplingg params.""" batch_size = len(seqs) temperature = [None] * batch_size @@ -154,7 +198,7 @@ def __gather_params(): top_k[idx] = param.top_k top_p[idx] = param.top_p min_p[idx] = param.min_p - random_offsets[idx] = seq.num_all_ids + random_offsets[idx] = seq.num_valid_ids response_formats[idx] = param.response_format if param.random_seed is not None: random_seeds[idx] = param.random_seed & 0xffffffff @@ -255,6 +299,10 @@ def __get_bad_words(bad_words): logits_processors=logits_processors, max_num_logprobs=max_num_logprobs, ) + + sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) + sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input) + sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) return sampling_input def to_device(self, device: str, non_blocking: bool = False): @@ -269,6 +317,62 @@ def to_device(self, device: str, non_blocking: bool = False): return SamplingInputs(**out_dict) + def step(self, next_token_ids: torch.Tensor, **kwargs): + """To next step.""" + self.num_ignore_eos = self.num_ignore_eos - 1 + if self.all_ids is not None: + self.all_ids = torch.cat([self.all_ids, next_token_ids[:, None]], 1) + if self.guided_input_ids is not None: + self.guided_input_ids = torch.cat([self.guided_input_ids, next_token_ids[:, None]], 1) + + +@dataclass +class SamplingInputsDLLM(SamplingInputs): + + @classmethod + def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0, block_sparse_size: int = 1): + """From samplingg params.""" + out = super().from_sampling_params(seqs, pad_token_id) + if block_sparse_size == 1: + return out + + # repeat tensor + update_attr_names = [ + 'temperature', + 'bad_words', + 'bad_mask', + 'stop_words', + 'stop_mask', + 'repetition_penalty', + 'top_k', + 'top_p', + 'min_p', + 'random_seeds', + 'random_offsets', + 'all_ids', + 'guided_input_ids', + 'num_ignore_eos', + ] + for name in update_attr_names: + attr = getattr(out, name) + if attr is None: + continue + repeats = (block_sparse_size, ) + (1, ) * (attr.dim()) + attr = attr[None].repeat(*repeats).flatten(0, 1) + setattr(out, name, attr) + + if len(out.response_formats) > 0: + new_resp_formats = [] + for resp in out.response_formats: + new_resp_formats += [resp] * block_sparse_size + out.response_formats = new_resp_formats + + return out + + def step(self, next_token_ids: torch.Tensor, **kwargs): + """To next 
step.""" + self.num_ignore_eos = self.num_ignore_eos - 1 + def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits): """Apply custom logits processors.""" diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index e601cccdfb..7c6db80fce 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -28,6 +28,7 @@ from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine from .logits_process import FusedLogitsProcessor, SamplingInputs +from .unmasking import UnmaskingMeta, UnmaskingProcessor logger = get_logger('lmdeploy') @@ -62,10 +63,12 @@ def to_tensor(self): class BatchedOutputs: next_token_ids: torch.Tensor stopped: torch.Tensor + stop_pos: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None model_metas: List[Dict[str, Any]] = None logprobs: Optional[BatchedLogProbs] = None new_token_timestamp: int = 0 + dllm_mask: Optional[torch.Tensor] = None def to_cpu(self): """To cpu.""" @@ -198,6 +201,8 @@ def msg_with_rank(rank: int, msg: str): def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict): """Perform cache swapping.""" issued_cache_op = False + swap_in_map = swap_in_map or dict() + swap_out_map = swap_out_map or dict() if len(swap_in_map) > 0: cache_engine.swap_in(swap_in_map) issued_cache_op = True @@ -241,17 +246,48 @@ def model_forward( return dict(hidden_states=output, model_metas=model_metas) -@record_function('stopping_criteria') -def _batch_stopping_criteria(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor): +def _batch_stopping_criteria_default(token_ids: torch.Tensor, stop_words: torch.Tensor, + num_appendable_ids: torch.Tensor): """Batched stopping criteria.""" num_appendable_ids = num_appendable_ids - 1 stopped = num_appendable_ids <= 0 if stop_words is not None: sw_stopped = (token_ids[:, None] == stop_words).any(1) + stop_pos = torch.zeros_like(num_appendable_ids) stopped = stopped | sw_stopped one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) - return stopped, num_appendable_ids + return stopped, stop_pos, num_appendable_ids + + +DLLM_MASKED = 0 +DLLM_UNMASKED = 1 +DLLM_CACHED = 2 + + +def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, + dllm_mask: torch.Tensor): + """Batched stopping criteria.""" + num_tokens = token_ids.size(0) + batch_size = num_appendable_ids.size(0) + block_sparse_size = num_tokens // batch_size + if block_sparse_size == 1: + return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) + + dllm_mask = dllm_mask.view(batch_size, block_sparse_size) + is_unmasked = (dllm_mask == DLLM_UNMASKED).all(dim=1) + num_appendable_ids -= is_unmasked * block_sparse_size + stopped = num_appendable_ids <= 0 + if stop_words is not None: + sw_stopped = (token_ids[:, None] == stop_words).any(1) + sw_stopped = sw_stopped.view(batch_size, block_sparse_size) + stop_pos = sw_stopped.int().argmax(1) + sw_stopped = sw_stopped.any(dim=1) + sw_stopped = sw_stopped & is_unmasked + stopped = stopped | sw_stopped + one_ids = torch.clamp_max(num_appendable_ids, 0) + num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + return stopped, stop_pos, num_appendable_ids def _try_to_cuda(val, non_blocking: bool = False): @@ -363,6 +399,16 @@ def __init__(self, 
self.enable_microbatch_decode_batchsize_threshold = \ int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2)) + # dllm + self.unmasking_processor = self._build_unmasking_processor() + + def _build_unmasking_processor(self): + """Build unmasking processor.""" + strategy = 'low_confidence_static' if self.model_config.model_paradigm == 'dllm' else None + unmasking_processor = UnmaskingProcessor( + UnmaskingMeta(strategy=strategy, block_sparse_size=self.misc_config.block_sparse_size, topk=2)) + return unmasking_processor + @contextmanager def all_context(self): device_mgr = get_device_manager() @@ -396,8 +442,9 @@ def warmup(self): inputs = ModelInputs.make_dummy(max_batches, is_decoding=False, device='cuda', - vocab_size=self.model_config.vocab_size) - self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict()) + vocab_size=self.model_config.vocab_size, + build_ctx=self.build_model_ctx) + self._forward_impl(inputs) # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() @@ -406,20 +453,92 @@ def warmup(self): inputs = ModelInputs.make_dummy(num_tokens, is_decoding=True, device='cuda', - vocab_size=self.model_config.vocab_size) - self._forward_impl(inputs, swap_in_map=dict(), swap_out_map=dict()) + vocab_size=self.model_config.vocab_size, + build_ctx=self.build_model_ctx) + self._forward_impl(inputs) + + def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): + """Slice outputs.""" + if self.model_config.model_paradigm == 'dllm': + block_sparse_size = self.misc_config.block_sparse_size + if len(seq_length) * block_sparse_size == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(0) + block_range = torch.arange(-block_sparse_size, 0, device=last_idx.device) + index = (last_idx[:, None] + block_range[None, :]).flatten() + inputs = inputs[index] + return inputs + else: + if len(seq_length) == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(-1) - 1 + return inputs[last_idx] + + def _postprocess_forward_output_dllm(self, output: dict, inputs: ModelInputs, is_long_context: bool, + return_logits: bool): + """Post process for dllm.""" + block_sparse_size = self.misc_config.block_sparse_size + hidden_states = output['hidden_states'] + seq_length = inputs.seq_length + if is_long_context: + if not return_logits: + hidden_states = hidden_states[:, -block_sparse_size:] + else: + hidden_states = hidden_states.to('cuda') + else: + is_decoding = seq_length.numel() * block_sparse_size == hidden_states.size(1) + if not return_logits and not is_decoding: + hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] + output['hidden_states'] = hidden_states + return output + + def _postprocess_forward_output_default(self, output: dict, inputs: ModelInputs, is_long_context: bool, + return_logits: bool): + """Post process forward output default.""" + hidden_states = output['hidden_states'] + seq_length = inputs.seq_length + if not is_long_context: + is_decoding = seq_length.numel() != hidden_states.size(1) + if not return_logits and not is_decoding: + hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] + else: + if not return_logits: + last_token_loc = [-1] + hidden_states = hidden_states[:, last_token_loc] + else: + hidden_states = hidden_states.to('cuda') + output['hidden_states'] = hidden_states + return output + + def _postprocess_forward_output(self, output: dict, inputs: ModelInputs, is_long_context: bool, + return_logits: bool): + """Post process forward output.""" + if 
self.model_config.model_paradigm == 'dllm': + return self._postprocess_forward_output_dllm(output, inputs, is_long_context, return_logits) + else: + return self._postprocess_forward_output_default(output, inputs, is_long_context, return_logits) + + @record_function('stopping_criteria') + def _batch_stopping_criteria(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + num_appendable_ids: torch.Tensor, + dllm_mask: Optional[torch.Tensor] = None): + """Batched stopping criteria.""" + if self.model_config.model_paradigm == 'dllm': + assert dllm_mask is not None + return _batch_stopping_criteria_dllm(token_ids, stop_words, num_appendable_ids, dllm_mask) + + return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) async def _async_model_forward( self, inputs: ModelInputs, - swap_in_map: Dict, - swap_out_map: Dict, return_logits: bool, sync_long_context: bool, ): """Model forward.""" max_prefill_token_num = self.cache_config.max_prefill_token_num - swap_done = False class _OutputGather: """Output gather.""" @@ -453,14 +572,7 @@ def get_output(self): torch.cuda.synchronize() return self._output - async def __forward(inputs): - """forward.""" - nonlocal swap_done, swap_in_map, swap_out_map - if swap_done: - return await self.async_forward(inputs, swap_in_map=dict(), swap_out_map=dict()) - else: - swap_done = True - return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + __forward = self.async_forward async def __long_context_single_forward(new_inputs, max_seqlen: int): """One large sequence.""" @@ -501,20 +613,18 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): if not is_long_context: ret = await __forward(inputs) - if not return_logits and not inputs.is_decoding: - last_token_loc = inputs.seq_length.cumsum(0) - 1 - ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] else: ret = await __long_context_single_forward(inputs, max_seqlen) - if not return_logits: - last_token_loc = [-1] - ret['hidden_states'] = ret['hidden_states'][:, last_token_loc] - else: - ret['hidden_states'] = ret['hidden_states'].to('cuda') + + ret = self._postprocess_forward_output(ret, inputs, is_long_context, return_logits) # compute dummy loop if dummy_loop > 0: - dummy_inputs = ModelInputs.make_dummy(1, False, 'cuda', vocab_size=self.model_config.vocab_size) + dummy_inputs = ModelInputs.make_dummy(1, + False, + 'cuda', + vocab_size=self.model_config.vocab_size, + build_ctx=self.build_model_ctx) for _ in range(dummy_loop): await __forward(dummy_inputs) @@ -523,29 +633,22 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): ret['logits'] = logits return ret - async def async_sampling_logits(self, logits: torch.Tensor, all_ids: torch.Tensor, guided_input_ids: torch.Tensor, - sampling_inputs: SamplingInputs, inputs: ModelInputs, ignore_eos: torch.Tensor): + async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: SamplingInputs, inputs: ModelInputs): """Sampling logits.""" - def __get_last_logits(): - """Get last logits.""" - seq_length = inputs.seq_length - if len(seq_length) == logits.size(0): - return logits - - last_idx = seq_length.cumsum(-1) - 1 - return logits[last_idx, :] - # record function does not support async function # so we can not decorate it on async_sampling_logits with record_function('sampling_logits'): - split_logits = __get_last_logits() + all_ids = sampling_inputs.all_ids + guided_input_ids = sampling_inputs.guided_input_ids + ignore_eos = 
sampling_inputs.num_ignore_eos > 0 logits_processor = FusedLogitsProcessor(sampling_inputs, ignore_eos, self.tokenizer, sampling_vocab_size=self.sampling_vocab_size, logprobs_mode=self.misc_config.logprobs_mode) - logits, raw_logprobs = await logits_processor(all_ids, guided_input_ids, split_logits) + origin_logits = logits + logits, raw_logprobs = await logits_processor(all_ids, guided_input_ids, origin_logits) next_token_ids = logits_processor.sampling(logits) logprobs = logits_processor.compute_logprobs(raw_logprobs, next_token_ids) if logprobs is not None: @@ -568,11 +671,20 @@ def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistCont if self.cache_config.role == EngineRole.Decode: next_token_ids = next_token_ids.cpu() tp_cpu_group = dist_ctx.tp_cpu_group - dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group) + handle = dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group, async_op=True) else: tp_gpu_group = dist_ctx.tp_gpu_group - dist.broadcast(next_token_ids, src=0, group=tp_gpu_group) - return next_token_ids + handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) + return handle + + def dllm_unmasking(self, inputs: ModelInputs, logits: torch.Tensor, next_token_ids: torch.LongTensor, + dllm_mask: torch.Tensor): + """Unmasking dllm.""" + if self.build_model_ctx.model_paradigm != 'dllm': + return None, next_token_ids + input_ids = inputs.input_ids + input_ids = self._slice_outs(input_ids.flatten(), inputs.seq_length) + return self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask) async def _async_step_background( self, @@ -580,36 +692,47 @@ async def _async_step_background( loop_count: int, swap_in_map: Dict = None, swap_out_map: Dict = None, - all_ids: torch.Tensor = None, - guided_input_ids: torch.Tensor = None, sampling_inputs: SamplingInputs = None, num_appendable_ids: torch.LongTensor = None, - num_ignore_eos: torch.LongTensor = None, return_logits: bool = False, is_dummy: bool = False, sync_long_context: bool = False, + dllm_mask: torch.Tensor = None, ): """Asyc forward task.""" - if swap_in_map is None: - swap_in_map = dict() + dist_ctx = get_dist_manager().current_context() - if swap_out_map is None: - swap_out_map = dict() + def __update_dllm(next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor): + """Update token_ids and dllm_mask.""" + model_paradigm = self.model_config.model_paradigm + if model_paradigm != 'dllm': + return next_token_ids, dllm_mask, seqlens - dist_ctx = get_dist_manager().current_context() + dllm_mask_token = self.model_config.dllm_mask_token + block_sparse_size = self.build_model_ctx.block_sparse_size + + # reshape to (batch, block_sparse_size) + next_token_ids = next_token_ids.view(-1, block_sparse_size).clone() + dllm_mask = dllm_mask.view(-1, block_sparse_size).clone() + + # flags + is_cached = (dllm_mask == DLLM_CACHED).all(dim=1) + + is_masked = (dllm_mask == DLLM_MASKED) + next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token + dllm_mask[is_cached] = DLLM_MASKED + seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, ))) + + return next_token_ids.flatten(), dllm_mask.flatten(), seqlens @record_function('update_inputs_for_next_step') - def __update_inputs(next_token_ids, model_metas): + def __update_inputs(next_token_ids, model_metas, dllm_mask): """Update inputs.""" - nonlocal all_ids, guided_input_ids, swap_in_map, swap_out_map - swap_in_map = dict() - swap_out_map = dict() 
inputs.model_metas = model_metas - inputs.update(next_token_ids) - if all_ids is not None: - all_ids = torch.cat([all_ids, next_token_ids[:, None].to(all_ids.device)], 1) - if guided_input_ids is not None: - guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None].to(guided_input_ids.device)], 1) + next_token_ids, dllm_mask, step_seqlens = __update_dllm(next_token_ids, dllm_mask, inputs.seq_length) + inputs.step(next_token_ids, step_seqlens) + sampling_inputs.step(next_token_ids) + return next_token_ids, dllm_mask @asynccontextmanager async def __prepare_dp(): @@ -693,46 +816,63 @@ async def __prepare_dp(): logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.') return + cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) for idx in range(loop_count): # inference logger.debug(f' rank[{rank}]: model forward [{idx}].') output = await self._async_model_forward( inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map, return_logits=return_logits, sync_long_context=sync_long_context, ) logits = output['logits'] logits = logits[0] # [bs, seq, prob] -> [seq, prob] + last_logits = self._slice_outs(logits, inputs.seq_length) + if dllm_mask is not None: + dllm_mask = self._slice_outs(dllm_mask, inputs.seq_length) # output empty for dummy inputs if is_dummy: continue + need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1) + # sampling and stopping if need_output: logger.debug(f' rank[{rank}]: Sampling [{idx}].') # sampling - next_token_ids, logprobs = await self.async_sampling_logits(logits, all_ids, guided_input_ids, - sampling_inputs, inputs, num_ignore_eos > 0) - num_ignore_eos = num_ignore_eos - 1 + next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs) + + # broadcast next token for TP > 1 + if need_broadcast_next: + logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') + handle = self._broadcast_next_token(next_token_ids, dist_ctx) + + # unmasking + dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) # stopping criteria - stopped, num_appendable_ids = _batch_stopping_criteria(next_token_ids, sampling_inputs.stop_words, - num_appendable_ids) + stopped, stop_pos, num_appendable_ids = self._batch_stopping_criteria(next_token_ids, + sampling_inputs.stop_words, + num_appendable_ids, + dllm_mask=dllm_mask) + if need_broadcast_next: + handle.wait() else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile. 
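In the decode loop above, TP token synchronization is issued with async_op=True and only waited on after the local unmasking and stopping work, so communication overlaps computation. A minimal self-contained sketch of that pattern; the single-process gloo group exists only so the snippet runs standalone, while the engine reuses its own tp_cpu_group / tp_gpu_group:

import os

import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29501')
dist.init_process_group('gloo', rank=0, world_size=1)

next_token_ids = torch.tensor([3, 7, 11])
# issue the collective asynchronously ...
handle = dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, async_op=True)
# ... keep doing independent local work while it is in flight ...
local_work = torch.arange(8).sum()
# ... and block only right before the synchronized tokens are needed
handle.wait()
print(next_token_ids.tolist(), int(local_work))
dist.destroy_process_group()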
with torch.inference_mode(): - next_token_ids = torch.zeros_like(num_ignore_eos) + next_token_ids = num_appendable_ids.new_zeros(last_logits.size(0)) logprobs = None - # broadcast next token for TP > 1 - need_broadcast_next = (dp == 1 and tp > 1 and idx < loop_count - 1) - if need_broadcast_next: - logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - next_token_ids = self._broadcast_next_token(next_token_ids, dist_ctx) + # broadcast next token for TP > 1 + if need_broadcast_next: + logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') + handle = self._broadcast_next_token(next_token_ids, dist_ctx) + handle.wait() + + # unmasking + dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) # send output model_metas = output.get('model_metas') @@ -742,12 +882,14 @@ async def __prepare_dp(): BatchedOutputs(next_token_ids=next_token_ids, logits=logits if return_logits else None, stopped=stopped, + stop_pos=stop_pos, model_metas=model_metas, - logprobs=logprobs)) + logprobs=logprobs, + dllm_mask=dllm_mask)) # update for next loop if is_decoding and idx < loop_count - 1: - __update_inputs(next_token_ids, model_metas) + next_token_ids, dllm_mask = __update_inputs(next_token_ids, model_metas, dllm_mask) async def _async_loop_background(self, forward_event: asyncio.Event = None): """Async loop background.""" @@ -775,7 +917,7 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None): async def _async_loop_inputs_preprocess(self): """Async loop inputs preprocess.""" non_blocking = True - keys = ['inputs', 'all_ids', 'guided_input_ids', 'sampling_inputs', 'num_appendable_ids', 'num_ignore_eos'] + keys = ['inputs', 'sampling_inputs', 'num_appendable_ids', 'dllm_mask'] while True: forward_inputs = await self._pre_in_que.get() @@ -911,7 +1053,9 @@ def _build_model(self): if custom_module_map is not None: update_custom_module_map(custom_module_map) logger.debug(msg_with_rank(rank, 'build model.')) - build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder) + build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder, + block_sparse_size=self.misc_config.block_sparse_size, + model_paradigm=self.model_config.model_paradigm) patched_model = build_patched_model(self.model_config, device=device, model_format=self.misc_config.model_format, @@ -923,6 +1067,7 @@ def _build_model(self): logger.debug(msg_with_rank(rank, 'loading adapters.')) add_adapters(patched_model, adapters, dtype=self.model_config.dtype, device=device) self.patched_model = patched_model + self.build_model_ctx = build_model_ctx def build_model(self): """Build model api.""" @@ -953,8 +1098,7 @@ def build_cache_engine(self): world_size=tp, cache_stream=self.cache_stream) - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): - cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + def _forward_impl(self, inputs: ModelInputs): output = model_forward( self.patched_model, inputs, @@ -963,7 +1107,7 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: ) return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): + async def async_forward(self, inputs: ModelInputs): """Model forward. Args: @@ -971,7 +1115,7 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_ou swap_in_map (SwapMap): Cache maps to swap in. 
swap_out_map (SwapMap): Cache maps to swap out. """ - output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) + output = self._forward_impl(inputs) await asyncio.sleep(0) return output @@ -1097,7 +1241,8 @@ def _make_dummy_forward_inputs(self): model_inputs = ModelInputs.make_dummy(batch_size, is_decoding, device=self.device, - vocab_size=self.model_config.vocab_size) + vocab_size=self.model_config.vocab_size, + build_ctx=self.model_agent.build_model_ctx) forward_inputs = dict( inputs=model_inputs, loop_count=loop_count, diff --git a/lmdeploy/pytorch/engine/unmasking.py b/lmdeploy/pytorch/engine/unmasking.py new file mode 100644 index 0000000000..def33dd0d3 --- /dev/null +++ b/lmdeploy/pytorch/engine/unmasking.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import dataclasses +from typing import Optional + +import torch + + +@dataclasses.dataclass +class UnmaskingMeta: + strategy: Optional[str] + block_sparse_size: int + topk: int = 1 + threshold: float = 0 + + +DLLM_MASKED = 0 +DLLM_UNMASKED = 1 +DLLM_CACHED = 2 + + +class UnmaskingProcessor: + + def __init__(self, meta: UnmaskingMeta): + self.meta = meta + + def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor): + """Get scores.""" + scores = logits.softmax(dim=-1) + scores = scores.gather(-1, token_ids.unsqueeze(-1)).flatten() + return scores + + def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): + """static.""" + block_sparse_size = self.meta.block_sparse_size + topk = min(self.meta.topk, block_sparse_size) + scores = self._get_scores(logits, token_ids) + is_masked = dllm_mask == DLLM_MASKED + scores = torch.where(is_masked, scores, scores.new_zeros((1, ))) + + scores = scores.view(-1, block_sparse_size) + dllm_mask = dllm_mask.view(-1, block_sparse_size) + _, indices = scores.topk(topk, dim=-1) + dllm_unmasked = dllm_mask.scatter(-1, indices, DLLM_UNMASKED) + + is_masked = is_masked.view_as(dllm_mask) + dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask) + return dllm_mask.flatten() + + def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): + """call.""" + strategy = self.meta.strategy + if strategy is None: + return dllm_mask + + # reshape to [num_blocks, block_sparse_size] + block_sparse_size = self.meta.block_sparse_size + dllm_mask = dllm_mask.unflatten(0, (-1, block_sparse_size)) + + is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1) + first_mask = dllm_mask[:, 0] + + # unmasked to cache + is_block_unmasked = is_same & (first_mask == DLLM_UNMASKED) + dllm_mask[is_block_unmasked] = DLLM_CACHED + + dllm_mask = dllm_mask.flatten() + token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids) + if strategy == 'low_confidence_static': + dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask) + else: + raise RuntimeError(f'strategy {strategy} not supported.') + + return dllm_mask, token_ids diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 027fd6839a..6c4ec8bc23 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -160,7 +160,9 @@ class MessageStatus(enum.Enum): class SequenceMeta: """Meta data shared by all sequence.""" block_size: int - model_type: str = 'llm' + model_paradigm: str = 'llm' + block_sparse_size: int = 1 + dllm_mask_token: int = 151669 class SequenceManager: @@ -222,6 +224,23 @@ def update_sequence_status(self, seq: 'SchedulerSequence', new_status: 
MessageSt new_status_map[seq_id] = seq +DLLM_MASKED = 0 +DLLM_UNMASKED = 1 +DLLM_CACHED = 2 +DLLM_MASK_DTYPE = np.uint8 + + +def _to_ndarray(token_ids) -> np.ndarray: + """To ndarray.""" + if isinstance(token_ids, Tensor): + token_ids = token_ids.numpy() + elif not isinstance(token_ids, np.ndarray): + token_ids = np.array(token_ids) + if token_ids.ndim == 0: + token_ids = token_ids[None] + return token_ids + + class SchedulerSession: """Scheduler session.""" @@ -242,30 +261,28 @@ def add_sequence(self, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" - if isinstance(token_ids, Tensor): - token_ids = token_ids.numpy() - elif not isinstance(token_ids, np.ndarray): - token_ids = np.array(token_ids) - if token_ids.ndim == 0: - token_ids = token_ids.unsqueeze(0) if sampling_param is None: sampling_param = SamplingParam() seq_id = self.seq_manager._new_seq_id() - seq = SchedulerSequence( + seq_cls = SEQ_CLS_MAP[self.seq_meta.model_paradigm] + seq = seq_cls( seq_id=seq_id, session=self, - history_cache=HistoryTokenIds(token_ids), num_new_tokens=0, sampling_param=sampling_param, adapter_name=adapter_name, arrive_time=time.perf_counter(), - history_embeddings=HistoryEmbeddings(input_embeddings), - history_multimodals=HistoryMultiModals(multimodals), migration_request=migration_request, resp_cache=resp_cache, preserve_cache=preserve_cache, ) + seq.update_token_ids( + token_ids, + multimodals=multimodals, + embeddings=input_embeddings, + mode=UpdateTokenMode.INPUTS, + ) self.sequences[seq.seq_id] = seq self.seq_manager.add_sequence(seq) return seq @@ -338,9 +355,9 @@ class HistoryTokenIds: """History token ids.""" ALLOC_SIZE = 512 - def __init__(self, token_ids: np.ndarray = None): + def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = np.int64): if token_ids is None: - self._token_ids = np.empty((self.ALLOC_SIZE, ), dtype=np.int64) + self._token_ids = np.empty((self.ALLOC_SIZE, ), dtype=dtype) self._num_real = 0 else: self._token_ids = token_ids @@ -359,6 +376,11 @@ def get_real(self): """Get logical blocks.""" return self._token_ids[:self._num_real] + def resize(self, size: int): + """Set size.""" + assert size <= self._num_real + self._num_real = size + def __setitem__(self, *args, **kwargs): """Set values.""" return self.get_real().__setitem__(*args, **kwargs) @@ -391,9 +413,15 @@ def copy(self): return self.clone() +class HistoryDLLMMask(HistoryTokenIds): + + def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = DLLM_MASK_DTYPE): + super().__init__(token_ids=token_ids, dtype=dtype) + + class HistoryMultiModals: - def __init__(self, multimodals: MultiModalInputs): + def __init__(self, multimodals: MultiModalInputs = None): if multimodals is None: multimodals = dict() self.multimodals = multimodals @@ -449,6 +477,13 @@ def get_encoder_len(self, start=0, end=-1): return out_len +class UpdateTokenMode(enum.Enum): + """Update token mode.""" + INPUTS = enum.auto() + PREFILL = enum.auto() + DECODE = enum.auto() + + @dataclass class SchedulerSequence: """Scheduler message.""" @@ -477,14 +512,15 @@ class SchedulerSequence: def __post_init__(self): """Post init.""" - self._num_history_ids: int = 0 + self._seq_meta: SequenceMeta = self.session.seq_meta self._num_history_images: int = 0 - self._num_images: int = len(self.history_embeddings) + self._num_history_ids: int = 0 self._num_token_ids: int = len(self.history_cache) + # vlm + self._num_images: int = len(self.history_embeddings) self._num_history_cross: int = 0 
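The counters initialized here drive the token-id views used throughout the refactor (num_history_ids + num_token_ids gives num_all_ids, and generated_ids is the last num_new_tokens valid ids). A toy stand-in that mirrors the update rules of SchedulerSequenceDefault.update_token_ids further below; it is illustrative only, not the real class:

import numpy as np


class SeqCounters:
    """Toy bookkeeping: INPUTS appends prompt tokens without counting them as
    generated; PREFILL/DECODE roll the previous chunk into history and count
    the sampled tokens as new."""

    def __init__(self):
        self.ids = np.empty(0, dtype=np.int64)
        self.num_history_ids = 0
        self.num_token_ids = 0
        self.num_new_tokens = 0

    def update(self, token_ids, is_input):
        token_ids = np.asarray(token_ids, dtype=np.int64)
        if is_input:
            self.num_token_ids += len(token_ids)
            self.num_new_tokens = 0
        else:
            self.num_history_ids += self.num_token_ids
            self.num_token_ids = len(token_ids)
            self.num_new_tokens += len(token_ids)
        self.ids = np.concatenate([self.ids, token_ids])

    @property
    def generated_ids(self):
        end = self.num_history_ids + self.num_token_ids  # num_valid_ids
        return self.ids[end - self.num_new_tokens:end]


seq = SeqCounters()
seq.update([1, 2, 3, 4], is_input=True)   # prompt
seq.update([99], is_input=False)          # first sampled token
seq.update([100], is_input=False)         # second sampled token
print(seq.num_history_ids, seq.num_token_ids, seq.generated_ids)  # 5 1 [ 99 100]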
self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids) - self._seq_meta: SequenceMeta = self.session.seq_meta @property def block_size(self) -> int: @@ -530,6 +566,17 @@ def all_ids(self) -> np.ndarray: """Full token ids.""" return self.history_cache._token_ids[:self.num_all_ids] + @property + def valid_ids(self) -> np.ndarray: + """Valid token ids.""" + return self.history_cache._token_ids[:self.num_valid_ids] + + @property + def generated_ids(self) -> np.ndarray: + end = self.num_valid_ids + start = end - self.num_new_tokens + return self.history_cache._token_ids[start:end] + @property def num_history_ids(self): """Num history ids.""" @@ -539,6 +586,10 @@ def num_history_ids(self): def num_token_ids(self): return self._num_token_ids + @property + def num_valid_ids(self): + return self._num_history_ids + self._num_token_ids + @property def num_images(self): return self._num_images @@ -591,57 +642,87 @@ def get_input_multimodals(self): end = self.num_all_ids return self.history_multimodals.get_datas(start, end) + def record_event( + self, + event_type: EventType, + timestamp: Optional[float] = None, + ) -> None: + self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) + def update_token_ids(self, token_ids: Tensor, multimodals: MultiModalInputs = None, embeddings: List[InputEmbeddings] = None, model_meta: Dict[str, Any] = None, - append_tokens: bool = False): + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): """Update token ids, old token ids will be added to history.""" - old_num_history_ids = self._num_history_ids + raise NotImplementedError('NotImplemented') - # update history - if not append_tokens: - self._num_history_ids += self._num_token_ids + def set_step(self, step: int): + """Set step.""" + raise NotImplementedError('NotImplemented') - # update history image nums - self._num_history_images += self._num_images - self._num_images = 0 - if embeddings is not None: - new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings] - self._num_images = len(new_embeddings) - self.history_embeddings.append(new_embeddings) - # update multimodals - if multimodals is not None: - multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids) - self.history_multimodals.add_inputs(multimodals) +@dataclass +class SchedulerSequenceDefault(SchedulerSequence): - # cross + def _update_embeddings(self, embeddings: List[InputEmbeddings]): + """Update input embeddings.""" + self._num_history_images += self._num_images + if embeddings is None: + self._num_images = 0 + return + new_embeddings = [emb.move_position(self._num_history_ids) for emb in embeddings] + self._num_images = len(new_embeddings) + self.history_embeddings.append(new_embeddings) + + def _update_multimodals(self, old_num_history_ids: int, multimodals: MultiModalInputs): + """Update input multimodals.""" self._num_history_cross += self._num_cross - if multimodals is not None: - self._num_cross = self.history_multimodals.get_encoder_len(old_num_history_ids, self._num_history_ids) - else: + if multimodals is None: self._num_cross = 0 + return + multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids) + self.history_multimodals.add_inputs(multimodals) - if model_meta is not None: - self.model_meta = model_meta + # for mllama + self._num_cross = self.history_multimodals.get_encoder_len(old_num_history_ids, self._num_history_ids) - if isinstance(token_ids, Tensor): - token_ids = token_ids.numpy() - elif not 
isinstance(token_ids, np.ndarray): - token_ids = np.array(token_ids) - if token_ids.ndim == 0: - token_ids = token_ids[None] - if append_tokens: - self._num_token_ids += len(token_ids) + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + old_num_history_ids = self._num_history_ids + self.arrive_time = time.perf_counter() + + token_ids = _to_ndarray(token_ids) + + num_valid = len(token_ids) + + if mode == UpdateTokenMode.INPUTS: + self._num_token_ids += num_valid self.num_new_tokens = 0 else: - num_token_ids = len(token_ids) + self._num_history_ids += self._num_token_ids + num_token_ids = num_valid self._num_token_ids = num_token_ids self.num_new_tokens += num_token_ids + self.history_cache.append(token_ids) - self.arrive_time = time.perf_counter() + + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(old_num_history_ids, multimodals) + + if model_meta is not None: + self.model_meta = model_meta def set_step(self, step: int): """Set step.""" @@ -663,9 +744,164 @@ def set_step(self, step: int): self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids) self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids) - def record_event( - self, - event_type: EventType, - timestamp: Optional[float] = None, - ) -> None: - self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) + +@dataclass +class SchedulerSequenceDLLM(SchedulerSequenceDefault): + + # For dllm + history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask) + + def __post_init__(self): + """Post init.""" + super().__post_init__() + self._num_valid_ids: int = len(self.history_cache) + + @property + def dllm_mask(self): + start = self.num_history_ids + end = start + self._num_token_ids + return self.history_dllm_mask._token_ids[start:end] + + @property + def num_valid_ids(self): + return self._num_valid_ids + + @property + def generated_ids(self) -> np.ndarray: + end = self.num_valid_ids + start = end - self.num_new_tokens + return self.history_cache._token_ids[start:end] + + @property + def all_dllm_mask(self): + return self.history_dllm_mask._token_ids[:self.num_all_ids] + + def set_stop_pos(self, pos: int): + block_sparse_size = self._seq_meta.block_sparse_size + val = block_sparse_size - pos - 1 + self._num_valid_ids -= val + self.num_new_tokens -= val + + def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Append tokens.""" + num_tokens = len(token_ids) + block_sparse_size = self._seq_meta.block_sparse_size + dllm_mask_token = self._seq_meta.dllm_mask_token + new_token_ids = [token_ids] + new_dllm_mask = [dllm_mask] + + # add uncached tokens in token_ids + # for example, [cccc cccc uumm], the [uu] in last block is remain valid. 
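A worked example of the alignment performed here, assuming block_sparse_size=4 and the default dllm_mask_token (151669): a 10-token prompt is padded with two masked slots so the sequence occupies whole sparse blocks.

import numpy as np

DLLM_MASKED, DLLM_UNMASKED = 0, 1
block_sparse_size = 4
dllm_mask_token = 151669

token_ids = np.arange(100, 110, dtype=np.int64)      # 10 prompt tokens
dllm_mask = np.full(10, DLLM_UNMASKED, dtype=np.uint8)

num_pad = (-len(token_ids)) % block_sparse_size      # 2 pad slots
token_ids = np.concatenate([token_ids, np.full(num_pad, dllm_mask_token, dtype=np.int64)])
dllm_mask = np.concatenate([dllm_mask, np.full(num_pad, DLLM_MASKED, dtype=np.uint8)])

assert len(token_ids) % block_sparse_size == 0       # 12 tokens == 3 sparse blocks
print(token_ids.reshape(-1, block_sparse_size))
print(dllm_mask.reshape(-1, block_sparse_size))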
+ num_remain_valid = self.num_valid_ids - self.num_history_ids + if num_remain_valid != 0: + prev_token_ids = self.valid_ids[-num_remain_valid:] + prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) + new_token_ids = [prev_token_ids] + new_token_ids + new_dllm_mask = [prev_dllm_mask] + new_dllm_mask + self.history_cache.resize(self.num_history_ids) + self.history_dllm_mask.resize(self.num_history_ids) + num_tokens += num_remain_valid + + # pad to align with block_sparse_size + num_pad = (-num_tokens) % block_sparse_size + if num_pad > 0: + pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, )) + pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, )) + new_token_ids += [pad_ids] + new_dllm_mask += [pad_mask] + + token_ids = np.concatenate(new_token_ids) + dllm_mask = np.concatenate(new_dllm_mask) + + assert len(token_ids) % block_sparse_size == 0 + + self.history_cache.append(token_ids) + self.history_dllm_mask.append(dllm_mask) + self._num_valid_ids = self.num_history_ids + num_tokens + self._num_token_ids = len(token_ids) + self.num_new_tokens = 0 + + def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Update token ids for decode.""" + num_tokens = len(token_ids) + block_sparse_size = self._seq_meta.block_sparse_size + dllm_mask_token = self._seq_meta.dllm_mask_token + assert num_tokens % block_sparse_size == 0 + num_history_ids = self.num_history_ids + + token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token + self.history_cache[num_history_ids:] = token_ids + self.history_dllm_mask[num_history_ids:] = dllm_mask + + # check if all blocks are cached + last_mask = dllm_mask[-block_sparse_size:] + is_unmasked = np.all(last_mask == DLLM_UNMASKED) + is_cached = np.all(last_mask == DLLM_CACHED) + + if is_unmasked: + num_new = block_sparse_size - self._num_valid_ids % block_sparse_size + self._num_valid_ids += num_new + self.num_new_tokens += num_new + + if is_cached: + # add new block + new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(block_sparse_size, )) + new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(block_sparse_size, )) + self.history_cache.append(new_token_ids) + self.history_dllm_mask.append(new_dllm_mask) + self._num_history_ids += self._num_token_ids + self._num_token_ids = block_sparse_size + + def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Update token ids for prefill.""" + block_sparse_size = self._seq_meta.block_sparse_size + num_history_ids = self.num_history_ids + + # fill input cache + if self.num_token_ids > block_sparse_size: + end = self.num_token_ids - block_sparse_size + self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED + self._num_history_ids += end + self._num_token_ids -= end + + # decoding update + self._update_token_ids_decode(token_ids, dllm_mask) + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + dllm_mask: Tensor = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + old_num_history_ids = self._num_history_ids + self.arrive_time = time.perf_counter() + + token_ids: np.ndarray = _to_ndarray(token_ids) + if dllm_mask is None: + dllm_mask = np.full_like(token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) + dllm_mask: np.ndarray = _to_ndarray(dllm_mask) + + if mode == UpdateTokenMode.INPUTS: + 
self._update_token_ids_inputs(token_ids, dllm_mask) + elif mode == UpdateTokenMode.PREFILL: + self._update_token_ids_prefill(token_ids, dllm_mask) + else: + self._update_token_ids_decode(token_ids, dllm_mask) + + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(old_num_history_ids, multimodals) + + if model_meta is not None: + self.model_meta = model_meta + + +SEQ_CLS_MAP = dict( + llm=SchedulerSequenceDefault, + dllm=SchedulerSequenceDLLM, +) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 9b697754e0..151f8636c9 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -142,12 +142,14 @@ class ModelInputs: dp_meta: 'DPMeta' = None enable_microbatch: bool = False - def update(self, input_ids: torch.LongTensor): + def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None): """Update input ids.""" assert self.is_decoding - self.history_lengths = self.history_lengths + 1 - self.max_kv_seqlen += 1 - self.sum_kv_seqlen += self.seq_length.numel() + if step_seqlens is None: + step_seqlens = self.seq_length + self.history_lengths = self.history_lengths + step_seqlens + self.max_kv_seqlen += self.max_q_seqlen + self.sum_kv_seqlen += self.max_kv_seqlen * self.seq_length.numel() if input_ids.dim() == 1: input_ids = input_ids[None, :] self.input_ids = input_ids @@ -286,11 +288,20 @@ def make_dummy(cls, is_decoding: bool, device: str = 'cpu', dummy_block_id: int = 0, - vocab_size: int = 1): + vocab_size: int = 1, + build_ctx: 'BuildModelContext' = None): """Make dummy inputs.""" + model_paradigm = build_ctx.model_paradigm + if model_paradigm == 'dllm': + block_sparse_size = build_ctx.block_sparse_size + max_q_seqlen = block_sparse_size + else: + max_q_seqlen = 1 + num_tokens = batch_size * max_q_seqlen + max_kv_seqlen = max_q_seqlen input_ids = torch.randint(0, vocab_size, ( 1, - batch_size, + num_tokens, ), dtype=torch.long, device=device) seq_length = torch.ones((batch_size, ), dtype=torch.long, device=device) history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device) @@ -304,8 +315,8 @@ def make_dummy(cls, block_offsets=block_offsets, is_decoding=is_decoding, num_ignored_history=num_ignored_history, - max_q_seqlen=1, - max_kv_seqlen=1, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, sum_kv_seqlen=batch_size, ) @@ -356,6 +367,7 @@ def new( model_config: ModelConfig, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, + build_ctx: 'BuildModelContext' = None, ): """Build step context. 
@@ -377,7 +389,7 @@ def new( inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) # position ids - attention_mask, position_ids = cls.get_mask_and_position_ids(inputs) + attention_mask, position_ids = cls.get_mask_and_position_ids(inputs, build_ctx) position_ids = position_ids[None] # [num_tokens] -> [1, num_tokens] q_start_loc = q_seqlens.cumsum(0) - q_seqlens @@ -420,17 +432,27 @@ def new( return ret @classmethod - def get_mask_and_position_ids(cls, inputs: ModelInputs): + def get_mask_and_position_ids(cls, inputs: ModelInputs, build_ctx: 'BuildModelContext'): """Get position ids.""" q_seqlens = inputs.seq_length history_seqlens = inputs.history_lengths + model_paradigm = build_ctx.model_paradigm + is_decoding = inputs.is_decoding # decoding - if inputs.is_decoding: - attention_mask = torch.ones_like(q_seqlens)[:, None] - position_ids = history_seqlens.unsqueeze(-1).clone() - position_ids = position_ids.flatten() - return attention_mask, position_ids + if is_decoding: + if model_paradigm == 'dllm': + block_sparse_size = build_ctx.block_sparse_size + attention_mask = None + ranges = torch.arange(0, block_sparse_size, device=q_seqlens.device) + position_ids = history_seqlens[:, None] + ranges[None, :] + position_ids = position_ids.flatten() + return attention_mask, position_ids + else: + attention_mask = torch.ones_like(q_seqlens)[:, None] + position_ids = history_seqlens.unsqueeze(-1).clone() + position_ids = position_ids.flatten() + return attention_mask, position_ids num_tokens = inputs.input_ids.numel() max_q_seqlen = inputs.max_q_seqlen @@ -449,14 +471,24 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): return attention_mask, position_ids_1d +@dataclass +class BuildModelContext: + """Context for building model.""" + disable_vision_encoder: bool = False + block_sparse_size: int = 1 + model_paradigm: str = 'llm' + + class StepContextManager: - def __init__(self): + def __init__(self, build_ctx: BuildModelContext = None): self._current_ctx = None + build_ctx = build_ctx or BuildModelContext() + self.build_ctx = build_ctx - @staticmethod @record_function('build_step_context') def build_context( + self, inputs: ModelInputs, model_config: ModelConfig, kv_caches: List = None, @@ -468,6 +500,7 @@ def build_context( model_config, kv_caches, kv_quant_policy, + build_ctx=self.build_ctx, ) def set_context(self, ctx: StepContext): diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 01d2c2fda6..7e3b9ebf4a 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -224,4 +224,9 @@ 'GptOssForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gpt_oss.GptOssForCausalLM', }) +# SDAR +MODULE_MAP.update({ + 'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM', +}) + CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py index 97ad271d9a..86aa5ea51f 100644 --- a/lmdeploy/pytorch/models/patch.py +++ b/lmdeploy/pytorch/models/patch.py @@ -9,9 +9,9 @@ from typing import Any, Dict import torch -from attr import dataclass from transformers.configuration_utils import PretrainedConfig +from lmdeploy.pytorch.model_inputs import BuildModelContext, StepContextManager from lmdeploy.utils import get_logger from ..config import ModelConfig @@ -188,10 +188,11 @@ def _get_model_class(config, module_map): def build_model_from_hf_config(model_config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None, + ctx_mgr: StepContextManager = 
None, build_model_ctx: 'BuildModelContext' = None): """Build model from hf config.""" - from lmdeploy.pytorch.model_inputs import StepContextManager - ctx_mgr = StepContextManager() + if ctx_mgr is None: + ctx_mgr = StepContextManager(build_model_ctx) module_map = _get_module_map() if device is None: device = torch.device('cuda') @@ -333,12 +334,6 @@ def add_adapters(model: torch.nn.Module, return target_infos -@dataclass -class BuildModelContext: - """Context for building model.""" - disable_vision_encoder: bool = False - - BUILD_MODEL_CTX = BuildModelContext() diff --git a/lmdeploy/pytorch/models/sdar.py b/lmdeploy/pytorch/models/sdar.py new file mode 100644 index 0000000000..ab1f4ced29 --- /dev/null +++ b/lmdeploy/pytorch/models/sdar.py @@ -0,0 +1,405 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import Any, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config +from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class SDARAttention(nn.Module): + """attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + # packed qkv + # Qwen3 uses 'config.attention_bias = False' for q/k/o projections + self.qkv_proj = build_qkv_proj(hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + num_replicate_kv_heads=num_replicate_kv_heads) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + block_sparse_size = config.block_sparse_size + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + block_sparse_size=block_sparse_size, + ) + + # o_proj + self.o_proj = build_o_proj(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + # q, k norm + self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states) + + # apply q, k norm + query_states = self.q_norm(query_states) + key_states = 
self.k_norm(key_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + ) + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class SDARMLP(nn.Module): + """mlp.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_gateup_linear( + config.hidden_size, + [config.intermediate_size, config.intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_down_linear(config.intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class SDARDecoderLayer(nn.Module): + """Decode layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = SDARAttention(config, dtype=dtype, device=device) + + # build MLP + self.mlp = SDARMLP(config, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class SDARModel(nn.Module): + """model.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + 
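SDARDecoderLayer.forward above threads a residual through the norms: the norm consumes (hidden_states, residual), adds them, and returns the normalized activations together with the new residual. A plain-torch stand-in of that add-then-normalize pattern, assuming standard RMSNorm semantics rather than lmdeploy's fused kernel:

import torch


def add_rms_norm(hidden, residual, weight, eps=1e-6):
    # add the carried residual first, then normalize; return both outputs
    if residual is not None:
        hidden = hidden + residual
    residual = hidden
    normed = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + eps) * weight
    return normed, residual


hidden = torch.randn(1, 4, 8)
weight = torch.ones(8)
normed, residual = add_rms_norm(hidden, None, weight)        # input_layernorm
attn_out = torch.randn_like(normed)                          # stand-in for attention output
normed, residual = add_rms_norm(attn_out, residual, weight)  # post_attention_layernorm
print(normed.shape, residual.shape)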
device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + SDARDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class SDARForCausalLM(nn.Module, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + config.block_sparse_size = ctx_mgr.build_ctx.block_sparse_size + # build model + self.model = SDARModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def update_weights(self): + """Update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is 
not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index b86adbbb89..9901bf443c 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -142,15 +142,15 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p new_inputs['cross_attn_metadata'] = cross_attn_metadata if is_decoding: - new_inputs['input_ids'] = input_buffers['input_ids'][:, :new_batch_size] - new_inputs['position_ids'] = input_buffers['position_ids'][:, :new_batch_size] + new_inputs['input_ids'] = input_buffers['input_ids'][:, :num_tokens] + new_inputs['position_ids'] = input_buffers['position_ids'][:, :num_tokens] else: new_inputs['input_ids'] = input_buffers['input_ids'] new_inputs['position_ids'] = input_buffers['position_ids'] if inputs_embeds is not None: if is_decoding: - new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'][:, :new_batch_size] + new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'][:, :num_tokens] else: new_inputs['inputs_embeds'] = input_buffers['inputs_embeds'] diff --git a/lmdeploy/pytorch/nn/attention.py b/lmdeploy/pytorch/nn/attention.py index ef5d831d51..7a1654db4b 100644 --- a/lmdeploy/pytorch/nn/attention.py +++ b/lmdeploy/pytorch/nn/attention.py @@ -37,6 +37,7 @@ def __init__( causal: bool = True, use_flash_mla: bool = False, learnable_sink: bool = False, + block_sparse_size: int = 1, **kwargs, ): super().__init__() @@ -61,6 +62,7 @@ def __init__( causal=causal, use_flash_mla=use_flash_mla, learnable_sink=learnable_sink, + block_sparse_size=block_sparse_size, **kwargs, ) diff --git a/lmdeploy/pytorch/paging/block_trie.py b/lmdeploy/pytorch/paging/block_trie.py index b5908dde54..d00690af24 100644 --- a/lmdeploy/pytorch/paging/block_trie.py +++ b/lmdeploy/pytorch/paging/block_trie.py @@ -81,7 +81,7 @@ def __match_success(node: Node): curr = node num_matched += block_size - while num_matched + block_size < seq.num_all_ids: + while num_matched + block_size < seq.num_valid_ids: curr_tokens = seq.history_cache[num_matched:num_matched + block_size] key = hash(('random', tuple(curr_tokens))) @@ -116,9 
+116,9 @@ def allocate(self, seq: SchedulerSequence): logical_blocks.last_shared_node = node num_matched = node.num_matched - num_all_ids = seq.num_all_ids + num_valid_ids = seq.num_valid_ids - if num_matched + block_size > num_all_ids: + if num_matched + block_size > num_valid_ids: return if len(node.children) == 0 and node.parent is not None: @@ -127,7 +127,7 @@ def allocate(self, seq: SchedulerSequence): block_id = num_matched // block_size blocks = [] free_blocks = [] - while num_matched + block_size <= num_all_ids: + while num_matched + block_size <= num_valid_ids: curr_tokens = seq.history_cache[num_matched:num_matched + block_size] block = logical_blocks[block_id] diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index bf36c74321..ef542974b6 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -36,7 +36,10 @@ class Scheduler: cache_config (CacheConfig): The config of cache info. """ - def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None: + def __init__(self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + seq_meta: SequenceMeta = None) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -50,7 +53,7 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type) - seq_meta = SequenceMeta(self.cache_config.block_size) + seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) self.seq_manager = SequenceManager(seq_meta) @property From 48a0137fa46207670292ab6803e61c8b01ee1ca6 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 1 Sep 2025 16:34:42 +0800 Subject: [PATCH 04/29] fix max_new_tokens;update profiler --- benchmark/profile_throughput.py | 2 ++ lmdeploy/cli/utils.py | 5 +++++ lmdeploy/pytorch/engine/model_agent.py | 6 ++++-- lmdeploy/pytorch/engine/unmasking.py | 2 ++ 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 91d3d86e58..f11e51a35d 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -307,6 +307,7 @@ def parse_args(): # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.eager_mode(pt_group) + ArgumentHelper.block_sparse_size(pt_group) tp_act = ArgumentHelper.tp(pt_group) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) @@ -363,6 +364,7 @@ def main(): quant_policy=args.quant_policy, dtype=args.dtype, distributed_executor_backend=args.distributed_executor_backend, + block_sparse_size=args.block_sparse_size, ) if args.use_uvloop: diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 111af08fd4..2aa543bd5e 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -610,6 +610,11 @@ def logprobs_mode(parser): choices=[None, 'raw_logits', 'raw_logprobs'], help='The mode of logprobs.') + @staticmethod + def block_sparse_size(parser): + """block_sparse_size for dllm.""" + return parser.add_argument('--block-sparse-size', type=int, default=1, help='Block sparse size for dllm') + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py class FlexibleArgumentParser(argparse.ArgumentParser): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 7c6db80fce..df23d68c63 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ 
b/lmdeploy/pytorch/engine/model_agent.py @@ -251,9 +251,9 @@ def _batch_stopping_criteria_default(token_ids: torch.Tensor, stop_words: torch. """Batched stopping criteria.""" num_appendable_ids = num_appendable_ids - 1 stopped = num_appendable_ids <= 0 + stop_pos = torch.zeros_like(num_appendable_ids) if stop_words is not None: sw_stopped = (token_ids[:, None] == stop_words).any(1) - stop_pos = torch.zeros_like(num_appendable_ids) stopped = stopped | sw_stopped one_ids = torch.clamp_max(num_appendable_ids, 0) num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) @@ -278,10 +278,12 @@ def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Ten is_unmasked = (dllm_mask == DLLM_UNMASKED).all(dim=1) num_appendable_ids -= is_unmasked * block_sparse_size stopped = num_appendable_ids <= 0 + stop_pos = block_sparse_size - 1 + num_appendable_ids if stop_words is not None: sw_stopped = (token_ids[:, None] == stop_words).any(1) sw_stopped = sw_stopped.view(batch_size, block_sparse_size) - stop_pos = sw_stopped.int().argmax(1) + sw_stop_pos = sw_stopped.int().argmax(1) + stop_pos = torch.where(stopped, stop_pos, sw_stop_pos) sw_stopped = sw_stopped.any(dim=1) sw_stopped = sw_stopped & is_unmasked stopped = stopped | sw_stopped diff --git a/lmdeploy/pytorch/engine/unmasking.py b/lmdeploy/pytorch/engine/unmasking.py index def33dd0d3..24697520d8 100644 --- a/lmdeploy/pytorch/engine/unmasking.py +++ b/lmdeploy/pytorch/engine/unmasking.py @@ -3,6 +3,7 @@ from typing import Optional import torch +from torch.profiler import record_function @dataclasses.dataclass @@ -46,6 +47,7 @@ def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, d dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask) return dllm_mask.flatten() + @record_function('unmasking') def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """call.""" strategy = self.meta.strategy From 6e8f4c52f173dada91966b4aef97b9612d0643ee Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 1 Sep 2025 18:08:30 +0800 Subject: [PATCH 05/29] add args --- benchmark/profile_throughput.py | 10 +++- lmdeploy/cli/serve.py | 8 +++ lmdeploy/cli/utils.py | 31 +++++++++- lmdeploy/messages.py | 14 ++++- lmdeploy/pytorch/config.py | 21 ++++++- lmdeploy/pytorch/consts.py | 5 ++ lmdeploy/pytorch/engine/engine.py | 2 +- lmdeploy/pytorch/engine/model_agent.py | 24 ++++---- lmdeploy/pytorch/engine/unmasking.py | 82 +++++++++++++++++++------- lmdeploy/pytorch/messages.py | 7 ++- 10 files changed, 158 insertions(+), 46 deletions(-) create mode 100644 lmdeploy/pytorch/consts.py diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index f11e51a35d..79db48aebf 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -307,7 +307,10 @@ def parse_args(): # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') ArgumentHelper.eager_mode(pt_group) - ArgumentHelper.block_sparse_size(pt_group) + ArgumentHelper.dllm_block_length(pt_group) + ArgumentHelper.dllm_unmasking_strategy(pt_group) + ArgumentHelper.dllm_denoising_steps(pt_group) + ArgumentHelper.dllm_confidence_threshold(pt_group) tp_act = ArgumentHelper.tp(pt_group) cache_count_act = ArgumentHelper.cache_max_entry_count(pt_group) @@ -364,7 +367,10 @@ def main(): quant_policy=args.quant_policy, dtype=args.dtype, distributed_executor_backend=args.distributed_executor_backend, - block_sparse_size=args.block_sparse_size, + 
dllm_block_length=args.dllm_block_length, + dllm_unmasking_strategy=args.dllm_unmasking_strategy, + dllm_denoising_steps=args.dllm_denoising_steps, + dllm_confidence_threshold=args.dllm_confidence_threshold, ) if args.use_uvloop: diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 0db760585a..4caa8e0a22 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -92,6 +92,10 @@ def add_parser_api_server(): ArgumentHelper.eager_mode(pt_group) ArgumentHelper.disable_vision_encoder(pt_group) ArgumentHelper.logprobs_mode(pt_group) + ArgumentHelper.dllm_block_length(pt_group) + ArgumentHelper.dllm_unmasking_strategy(pt_group) + ArgumentHelper.dllm_denoising_steps(pt_group) + ArgumentHelper.dllm_confidence_threshold(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) @@ -219,6 +223,10 @@ def api_server(args): hf_overrides=args.hf_overrides, disable_vision_encoder=args.disable_vision_encoder, logprobs_mode=args.logprobs_mode, + dllm_block_length=args.dllm_block_length, + dllm_unmasking_strategy=args.dllm_unmasking_strategy, + dllm_denoising_steps=args.dllm_denoising_steps, + dllm_confidence_threshold=args.dllm_confidence_threshold, ) else: from lmdeploy.messages import TurbomindEngineConfig diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 2aa543bd5e..bdfe2b5217 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -611,9 +611,34 @@ def logprobs_mode(parser): help='The mode of logprobs.') @staticmethod - def block_sparse_size(parser): - """block_sparse_size for dllm.""" - return parser.add_argument('--block-sparse-size', type=int, default=1, help='Block sparse size for dllm') + def dllm_block_length(parser): + """dllm_block_length for dllm.""" + return parser.add_argument('--dllm-block-length', type=int, default=1, help='Block length for dllm') + + @staticmethod + def dllm_unmasking_strategy(parser): + """Dllm unmasking strategy.""" + return parser.add_argument('--dllm-unmasking-strategy', + type=str, + default='low_confidence_dynamic', + choices=['low_confidence_dynamic', 'low_confidence_static', 'sequential'], + help='The unmasking strategy for dllm.') + + @staticmethod + def dllm_denoising_steps(parser): + """Dllm denoising steps.""" + return parser.add_argument('--dllm-denoising-steps', + type=int, + default=None, + help='The number of denoising steps for dllm.') + + @staticmethod + def dllm_confidence_threshold(parser): + """Dllm confidence threshold.""" + return parser.add_argument('--dllm-confidence-threshold', + type=float, + default=0.85, + help='The confidence threshold for dllm.') # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 2448d6e53f..51a3748720 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -334,8 +334,13 @@ class PytorchEngineConfig: It can be used to override the default config of the model, disable_vision_encoder (bool): Whether to disable loading vision encoder. Default to False. - block_sparse_size (int): Block size of block diffusion model. logprobs_mode (str): The mode of logprob, options: ['raw_logits', 'raw_logprobs'] + dllm_block_length (int): Block size of block diffusion model. + dllm_unmasking_strategy (str): Dllm unmasking strategy, options: + ['low_confidence_dynamic', 'low_confidence_static', 'sequential']. + dllm_denoising_steps (int): Dllm denoising steps. + dllm_confidence_threshold (float): dllm unmasking threshold for + dynamic unmasking. 
""" dtype: str = 'auto' tp: int = 1 @@ -369,9 +374,14 @@ class PytorchEngineConfig: enable_metrics: bool = False hf_overrides: Optional[Dict[str, Any]] = None disable_vision_encoder: bool = False - block_sparse_size: int = 1 logprobs_mode: str = None + # dllm + dllm_block_length: int = 1 + dllm_unmasking_strategy: str = 'low_confidence_dynamic' + dllm_denoising_steps: int = None + dllm_confidence_threshold: float = 0.85 + role: EngineRole = EngineRole.Hybrid migration_backend: MigrationBackend = MigrationBackend.DLSlime diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 6cb6d5cca7..1326c6e4b2 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -287,6 +287,14 @@ def from_hf_config(cls, return model_config +@dataclass +class DLLMConfig: + block_sparse_size: int = 1 + unmasking_strategy: str = 'low_confidence_dynamic' + denoising_steps: int = None + confidence_threshold: float = 0.85 + + @dataclass class MiscConfig: prefill_interval: int = 16 @@ -297,16 +305,25 @@ class MiscConfig: disable_vision_encoder: bool = False logprobs_mode: str = None block_sparse_size: int = 1 + dllm_config: DLLMConfig = None @classmethod def from_engine_config(cls, engine_config: PytorchEngineConfig): """From engine config.""" + denoising_steps = engine_config.dllm_denoising_steps + if denoising_steps is None: + denoising_steps = engine_config.dllm_block_length // 2 + dllm_config = DLLMConfig(block_sparse_size=engine_config.dllm_block_length, + unmasking_strategy=engine_config.dllm_unmasking_strategy, + denoising_steps=denoising_steps, + confidence_threshold=engine_config.dllm_confidence_threshold) misc_config = cls(custom_module_map=engine_config.custom_module_map, empty_init=engine_config.empty_init, prefill_interval=engine_config.prefill_interval, model_format=engine_config.model_format, hf_overrides=engine_config.hf_overrides, disable_vision_encoder=engine_config.disable_vision_encoder, - block_sparse_size=engine_config.block_sparse_size, - logprobs_mode=engine_config.logprobs_mode) + block_sparse_size=engine_config.dllm_block_length, + logprobs_mode=engine_config.logprobs_mode, + dllm_config=dllm_config) return misc_config diff --git a/lmdeploy/pytorch/consts.py b/lmdeploy/pytorch/consts.py new file mode 100644 index 0000000000..93dc2f4d71 --- /dev/null +++ b/lmdeploy/pytorch/consts.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# dllm +DLLM_MASKED = 0 +DLLM_UNMASKED = 1 +DLLM_CACHED = 2 diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index bfb794a8e6..4c87e025a8 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -144,7 +144,7 @@ def _build_seq_meta(cache_config: CacheConfig, model_config: ModelConfig, engine seq_meta = SequenceMeta(cache_config.block_size, model_paradigm=model_config.model_paradigm, - block_sparse_size=engine_config.block_sparse_size, + block_sparse_size=engine_config.dllm_block_length, dllm_mask_token=model_config.dllm_mask_token) return seq_meta diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index df23d68c63..76f69fbe1f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -14,6 +14,7 @@ import torch.distributed as dist from torch.profiler import ProfilerActivity, profile, record_function +from lmdeploy.pytorch import consts from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.serve.openai.protocol import UpdateParamsRequest from lmdeploy.utils import get_logger @@ -28,7 +29,7 @@ from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine from .logits_process import FusedLogitsProcessor, SamplingInputs -from .unmasking import UnmaskingMeta, UnmaskingProcessor +from .unmasking import UnmaskingProcessor logger = get_logger('lmdeploy') @@ -260,11 +261,6 @@ def _batch_stopping_criteria_default(token_ids: torch.Tensor, stop_words: torch. return stopped, stop_pos, num_appendable_ids -DLLM_MASKED = 0 -DLLM_UNMASKED = 1 -DLLM_CACHED = 2 - - def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, dllm_mask: torch.Tensor): """Batched stopping criteria.""" @@ -275,7 +271,7 @@ def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Ten return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) dllm_mask = dllm_mask.view(batch_size, block_sparse_size) - is_unmasked = (dllm_mask == DLLM_UNMASKED).all(dim=1) + is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1) num_appendable_ids -= is_unmasked * block_sparse_size stopped = num_appendable_ids <= 0 stop_pos = block_sparse_size - 1 + num_appendable_ids @@ -406,9 +402,11 @@ def __init__(self, def _build_unmasking_processor(self): """Build unmasking processor.""" - strategy = 'low_confidence_static' if self.model_config.model_paradigm == 'dllm' else None - unmasking_processor = UnmaskingProcessor( - UnmaskingMeta(strategy=strategy, block_sparse_size=self.misc_config.block_sparse_size, topk=2)) + # block_sparse_size = self.misc_config.block_sparse_size + # strategy = 'low_confidence_dynamic' if self.model_config.model_paradigm == 'dllm' else None + # denoising_steps = max(1, block_sparse_size // 2) + dllm_config = self.misc_config.dllm_config + unmasking_processor = UnmaskingProcessor(dllm_config) return unmasking_processor @contextmanager @@ -718,11 +716,11 @@ def __update_dllm(next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens dllm_mask = dllm_mask.view(-1, block_sparse_size).clone() # flags - is_cached = (dllm_mask == DLLM_CACHED).all(dim=1) + is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1) - is_masked = (dllm_mask == DLLM_MASKED) + is_masked = (dllm_mask == consts.DLLM_MASKED) next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token - dllm_mask[is_cached] = DLLM_MASKED + dllm_mask[is_cached] = 
consts.DLLM_MASKED seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, ))) return next_token_ids.flatten(), dllm_mask.flatten(), seqlens diff --git a/lmdeploy/pytorch/engine/unmasking.py b/lmdeploy/pytorch/engine/unmasking.py index 24697520d8..0433b5a4b1 100644 --- a/lmdeploy/pytorch/engine/unmasking.py +++ b/lmdeploy/pytorch/engine/unmasking.py @@ -1,28 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. -import dataclasses -from typing import Optional - import torch from torch.profiler import record_function +from lmdeploy.pytorch import consts +from lmdeploy.pytorch.config import DLLMConfig -@dataclasses.dataclass -class UnmaskingMeta: - strategy: Optional[str] - block_sparse_size: int - topk: int = 1 - threshold: float = 0 - - -DLLM_MASKED = 0 -DLLM_UNMASKED = 1 -DLLM_CACHED = 2 +DLLM_MASKED = consts.DLLM_MASKED +DLLM_UNMASKED = consts.DLLM_UNMASKED +DLLM_CACHED = consts.DLLM_CACHED class UnmaskingProcessor: - def __init__(self, meta: UnmaskingMeta): - self.meta = meta + def __init__(self, dllm_config: DLLMConfig): + self.dllm_config = dllm_config def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor): """Get scores.""" @@ -30,10 +21,20 @@ def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor): scores = scores.gather(-1, token_ids.unsqueeze(-1)).flatten() return scores + def _get_denoise_num(self): + """Get denoise num.""" + block_sparse_size = self.dllm_config.block_sparse_size + denoising_steps = self.dllm_config.denoising_steps + if denoising_steps is None: + denoising_steps = block_sparse_size + num = block_sparse_size // self.dllm_config.denoising_steps + num = max(1, min(num, block_sparse_size)) + return num + def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """static.""" - block_sparse_size = self.meta.block_sparse_size - topk = min(self.meta.topk, block_sparse_size) + block_sparse_size = self.dllm_config.block_sparse_size + topk = self._get_denoise_num() scores = self._get_scores(logits, token_ids) is_masked = dllm_mask == DLLM_MASKED scores = torch.where(is_masked, scores, scores.new_zeros((1, ))) @@ -47,15 +48,52 @@ def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, d dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask) return dllm_mask.flatten() + def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): + """dynamic.""" + block_sparse_size = self.dllm_config.block_sparse_size + threshold = self.dllm_config.confidence_threshold + scores = self._get_scores(logits, token_ids) + is_masked = dllm_mask == DLLM_MASKED + scores = torch.where(is_masked, scores, scores.new_zeros((1, ))) + + scores = scores.view(-1, block_sparse_size) + dllm_mask = dllm_mask.view(-1, block_sparse_size) + _, indices = scores.topk(1, dim=-1) + scores = scores.scatter(-1, indices, threshold) + + is_masked = is_masked.view_as(dllm_mask) + is_masked &= scores >= threshold + dllm_mask[is_masked] = DLLM_UNMASKED + return dllm_mask.flatten() + + def sequential(self, dllm_mask: torch.Tensor): + """sequential.""" + block_sparse_size = self.dllm_config.block_sparse_size + denoise_num = self._get_denoise_num() + dllm_mask = dllm_mask.view(-1, block_sparse_size) + is_masked = dllm_mask == DLLM_MASKED + + # get indices + indices = is_masked.int().argmax(dim=1) + ranges = torch.arange(0, denoise_num, device=indices.device, dtype=indices.dtype) + indices = indices[:, None] + ranges[None, :] + indices = indices % 
block_sparse_size + + dllm_unmasked = dllm_mask.clone() + dllm_unmasked = dllm_unmasked.scatter(-1, indices, DLLM_UNMASKED) + dllm_mask = torch.where(is_masked, dllm_unmasked, dllm_mask) + + return dllm_mask.flatten() + @record_function('unmasking') def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """call.""" - strategy = self.meta.strategy + strategy = self.dllm_config.unmasking_strategy if strategy is None: return dllm_mask # reshape to [num_blocks, block_sparse_size] - block_sparse_size = self.meta.block_sparse_size + block_sparse_size = self.dllm_config.block_sparse_size dllm_mask = dllm_mask.unflatten(0, (-1, block_sparse_size)) is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1) @@ -69,6 +107,10 @@ def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: tor token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids) if strategy == 'low_confidence_static': dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask) + elif strategy == 'low_confidence_dynamic': + dllm_mask = self.low_confidence_dynamic(logits, token_ids, dllm_mask) + elif strategy == 'sequential': + dllm_mask = self.sequential(dllm_mask) else: raise RuntimeError(f'strategy {strategy} not supported.') diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 6c4ec8bc23..711ec2b3a7 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -8,6 +8,7 @@ from torch import Tensor from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor +from lmdeploy.pytorch import consts from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger @@ -224,9 +225,9 @@ def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageSt new_status_map[seq_id] = seq -DLLM_MASKED = 0 -DLLM_UNMASKED = 1 -DLLM_CACHED = 2 +DLLM_MASKED = consts.DLLM_MASKED +DLLM_UNMASKED = consts.DLLM_UNMASKED +DLLM_CACHED = consts.DLLM_CACHED DLLM_MASK_DTYPE = np.uint8 From 42f4582b0420f9443209f7c1952050f42f202715 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 1 Sep 2025 20:48:10 +0800 Subject: [PATCH 06/29] fix multiround stop words --- lmdeploy/cli/cli.py | 1 + lmdeploy/pytorch/engine/engine.py | 6 +++ lmdeploy/pytorch/engine/model_agent.py | 72 ++++++++++++++++++++------ lmdeploy/pytorch/messages.py | 3 ++ 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/lmdeploy/cli/cli.py b/lmdeploy/cli/cli.py index 916b91e527..d71198791f 100644 --- a/lmdeploy/cli/cli.py +++ b/lmdeploy/cli/cli.py @@ -55,6 +55,7 @@ def add_parser_chat(): ArgumentHelper.adapters(pt_group) ArgumentHelper.device(pt_group) ArgumentHelper.eager_mode(pt_group) + ArgumentHelper.dllm_block_length(pt_group) # common engine args dtype_act = ArgumentHelper.dtype(pt_group) tp_act = ArgumentHelper.tp(pt_group) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 4c87e025a8..46ff82daa7 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -980,6 +980,10 @@ def __make_sampling_inputs(seqs: SeqList): block_sparse_size=block_sparse_size) return SamplingInputs.from_sampling_params(seqs, pad_token_id=pad_id) + def __get_input_pos(seqs: SeqList): + pos = [seq.input_pos for seq in seqs] + return torch.tensor(pos) + scheduler = self.scheduler logger.debug(f'Make forward inputs with prefill={prefill}, 
enable_empty={enable_empty}') @@ -1008,6 +1012,7 @@ def __make_sampling_inputs(seqs: SeqList): num_appendable_ids = __get_num_appendable_ids(running) return_logits = __need_logits(running) dllm_mask = __get_dllm_mask(running) + input_pos = __get_input_pos(running) sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num return dict( @@ -1022,6 +1027,7 @@ def __make_sampling_inputs(seqs: SeqList): is_dummy=False, sync_long_context=sync_long_context, dllm_mask=dllm_mask, + input_pos=input_pos, ) async def _await_forward_event(self, forward_event: asyncio.Event): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 76f69fbe1f..3f1aadf528 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -261,30 +261,64 @@ def _batch_stopping_criteria_default(token_ids: torch.Tensor, stop_words: torch. return stopped, stop_pos, num_appendable_ids +def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor, + stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor, + input_pos: torch.Tensor, inputs: ModelInputs): + num_tokens = token_ids.size(0) + batch_size = num_appendable_ids.size(0) + block_sparse_size = num_tokens // batch_size + + # blocks might contain stop words in prev-round chat + # these stop words should be ignored + kv_seqlens = inputs.history_lengths + inputs.seq_length + ignore_pos = (input_pos - (kv_seqlens - block_sparse_size)).clamp_min(0) + ignore_range = torch.arange(0, block_sparse_size, dtype=ignore_pos.dtype, device=ignore_pos.device) + ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten() + token_ids = token_ids.clone() + token_ids[ignore_mask] = -1 + + # find stop words + sw_stopped = (token_ids[:, None] == stop_words).any(1) + sw_stopped = sw_stopped.view(batch_size, block_sparse_size) + sw_stop_pos = sw_stopped.int().argmax(1) + + stop_pos = torch.where(stopped, stop_pos, sw_stop_pos) + sw_stopped = sw_stopped.any(dim=1) + sw_stopped = sw_stopped & is_unmasked + stopped = stopped | sw_stopped + + # update num_appendable_ids + one_ids = torch.clamp_max(num_appendable_ids, 0) + num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + + return stopped, stop_pos, num_appendable_ids + + def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, - dllm_mask: torch.Tensor): + dllm_mask: torch.Tensor, input_pos: torch.Tensor, inputs: ModelInputs): """Batched stopping criteria.""" num_tokens = token_ids.size(0) batch_size = num_appendable_ids.size(0) block_sparse_size = num_tokens // batch_size - if block_sparse_size == 1: - return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) dllm_mask = dllm_mask.view(batch_size, block_sparse_size) is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1) + + # check stop by num_new_tokens num_appendable_ids -= is_unmasked * block_sparse_size stopped = num_appendable_ids <= 0 stop_pos = block_sparse_size - 1 + num_appendable_ids + + # check stop words if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) - sw_stopped = sw_stopped.view(batch_size, block_sparse_size) - sw_stop_pos = sw_stopped.int().argmax(1) - stop_pos = torch.where(stopped, stop_pos, sw_stop_pos) - sw_stopped = sw_stopped.any(dim=1) - sw_stopped = sw_stopped & is_unmasked - stopped = stopped | sw_stopped - one_ids = 
torch.clamp_max(num_appendable_ids, 0) - num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids, + stop_words, + is_unmasked, + stopped, + stop_pos, + num_appendable_ids, + input_pos=input_pos, + inputs=inputs) return stopped, stop_pos, num_appendable_ids @@ -523,11 +557,14 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, - dllm_mask: Optional[torch.Tensor] = None): + dllm_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + inputs: Optional[ModelInputs] = None): """Batched stopping criteria.""" if self.model_config.model_paradigm == 'dllm': assert dllm_mask is not None - return _batch_stopping_criteria_dllm(token_ids, stop_words, num_appendable_ids, dllm_mask) + return _batch_stopping_criteria_dllm(token_ids, stop_words, num_appendable_ids, dllm_mask, input_pos, + inputs) return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) @@ -698,6 +735,7 @@ async def _async_step_background( is_dummy: bool = False, sync_long_context: bool = False, dllm_mask: torch.Tensor = None, + input_pos: torch.Tensor = None, ): """Asyc forward task.""" dist_ctx = get_dist_manager().current_context() @@ -855,7 +893,9 @@ async def __prepare_dp(): stopped, stop_pos, num_appendable_ids = self._batch_stopping_criteria(next_token_ids, sampling_inputs.stop_words, num_appendable_ids, - dllm_mask=dllm_mask) + dllm_mask=dllm_mask, + input_pos=input_pos, + inputs=inputs) if need_broadcast_next: handle.wait() else: @@ -917,7 +957,7 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None): async def _async_loop_inputs_preprocess(self): """Async loop inputs preprocess.""" non_blocking = True - keys = ['inputs', 'sampling_inputs', 'num_appendable_ids', 'dllm_mask'] + keys = ['inputs', 'sampling_inputs', 'num_appendable_ids', 'dllm_mask', 'input_pos'] while True: forward_inputs = await self._pre_in_que.get() diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 711ec2b3a7..15ae47d466 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -498,6 +498,7 @@ class SchedulerSequence: logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks) adapter_name: str = None arrive_time: float = 0.0 + input_pos: int = 0 meta: Any = None _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 @@ -706,6 +707,7 @@ def update_token_ids(self, num_valid = len(token_ids) if mode == UpdateTokenMode.INPUTS: + self.input_pos = self.num_all_ids self._num_token_ids += num_valid self.num_new_tokens = 0 else: @@ -818,6 +820,7 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray) self.history_cache.append(token_ids) self.history_dllm_mask.append(dllm_mask) + self.input_pos = self._num_valid_ids self._num_valid_ids = self.num_history_ids + num_tokens self._num_token_ids = len(token_ids) self.num_new_tokens = 0 From 9a68f1aace772098c8b915179d50d95089bff759 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 2 Sep 2025 14:04:27 +0800 Subject: [PATCH 07/29] fix sampling step --- lmdeploy/pytorch/engine/logits_process.py | 15 +++++++++++++-- lmdeploy/pytorch/engine/model_agent.py | 2 +- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 89be71696c..aaed429f3d 100644 --- 
a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -171,6 +171,7 @@ class SamplingInputs: all_ids: Optional[torch.Tensor] = None guided_input_ids: Optional[torch.Tensor] = None num_ignore_eos: torch.Tensor = None + batch_size: int = 0 @classmethod def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0): @@ -298,6 +299,7 @@ def __get_bad_words(bad_words): min_top_p=min_top_p, logits_processors=logits_processors, max_num_logprobs=max_num_logprobs, + batch_size=batch_size, ) sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) @@ -367,11 +369,20 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = new_resp_formats += [resp] * block_sparse_size out.response_formats = new_resp_formats + out.batch_size *= block_sparse_size + return out - def step(self, next_token_ids: torch.Tensor, **kwargs): + def step(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, **kwargs): """To next step.""" - self.num_ignore_eos = self.num_ignore_eos - 1 + from lmdeploy.pytorch import consts + batch_size = self.batch_size + block_sparse_size = next_token_ids.numel() // batch_size + DLLM_UNMASKED = consts.DLLM_UNMASKED + is_unmasked = (dllm_mask == DLLM_UNMASKED).view(batch_size, -1).all(dim=1, keepdim=True) + num_ignore_eos = self.num_ignore_eos.view(batch_size, -1) + num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - block_sparse_size, num_ignore_eos) + self.num_ignore_eos = num_ignore_eos.flatten() def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 3f1aadf528..55b8253084 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -769,7 +769,7 @@ def __update_inputs(next_token_ids, model_metas, dllm_mask): inputs.model_metas = model_metas next_token_ids, dllm_mask, step_seqlens = __update_dllm(next_token_ids, dllm_mask, inputs.seq_length) inputs.step(next_token_ids, step_seqlens) - sampling_inputs.step(next_token_ids) + sampling_inputs.step(next_token_ids, dllm_mask=dllm_mask) return next_token_ids, dllm_mask @asynccontextmanager From 0fa2e7ee31f04cfd3a208d5de0b4f35694fa2d7a Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 2 Sep 2025 15:36:38 +0800 Subject: [PATCH 08/29] optimize position_ids --- lmdeploy/pytorch/model_inputs.py | 37 +++++++++++++++----------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index d64669cf7c..f3d18112c5 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -369,7 +369,6 @@ def new( model_config: ModelConfig, kv_caches: List = None, kv_quant_policy: Literal[0, 4, 8] = 0, - build_ctx: 'BuildModelContext' = None, ): """Build step context. 
@@ -391,7 +390,7 @@ def new( inputs.vision_inputs.get_inputs(history_seqlens, q_seqlens) # position ids - attention_mask, position_ids = cls.get_mask_and_position_ids(inputs, build_ctx) + attention_mask, position_ids = cls.get_mask_and_position_ids(inputs) position_ids = position_ids[None] # [num_tokens] -> [1, num_tokens] q_start_loc = q_seqlens.cumsum(0) - q_seqlens @@ -434,32 +433,31 @@ def new( return ret @classmethod - def get_mask_and_position_ids(cls, inputs: ModelInputs, build_ctx: 'BuildModelContext'): + def get_mask_and_position_ids(cls, inputs: ModelInputs): """Get position ids.""" q_seqlens = inputs.seq_length history_seqlens = inputs.history_lengths - model_paradigm = build_ctx.model_paradigm - is_decoding = inputs.is_decoding + max_q_seqlen = inputs.max_q_seqlen # decoding - if is_decoding: - if model_paradigm == 'dllm': - block_sparse_size = build_ctx.block_sparse_size - attention_mask = None - ranges = torch.arange(0, block_sparse_size, device=q_seqlens.device) - position_ids = history_seqlens[:, None] + ranges[None, :] - position_ids = position_ids.flatten() - return attention_mask, position_ids - else: - attention_mask = torch.ones_like(q_seqlens)[:, None] - position_ids = history_seqlens.unsqueeze(-1).clone() - position_ids = position_ids.flatten() - return attention_mask, position_ids + if max_q_seqlen == 1: + attention_mask = torch.ones_like(q_seqlens)[:, None] + position_ids = history_seqlens.unsqueeze(-1).clone() + position_ids = position_ids.flatten() + return attention_mask, position_ids num_tokens = inputs.input_ids.numel() - max_q_seqlen = inputs.max_q_seqlen + batch_size = inputs.seq_length.numel() device = q_seqlens.device + # batch with same seqlens + if max_q_seqlen * batch_size == num_tokens: + attention_mask = None + ranges = torch.arange(0, max_q_seqlen, device=device) + position_ids = history_seqlens[:, None] + ranges[None, :] + position_ids = position_ids.flatten() + return attention_mask, position_ids + # get mask mask_range = torch.arange(max_q_seqlen, device=device)[None, :] attention_mask = (mask_range < q_seqlens[:, None]).long() @@ -502,7 +500,6 @@ def build_context( model_config, kv_caches, kv_quant_policy, - build_ctx=self.build_ctx, ) def set_context(self, ctx: StepContext): From 85255d26a2a4bed941bf5f6b352f05e45ce6b75a Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 2 Sep 2025 16:23:56 +0800 Subject: [PATCH 09/29] fix long context --- lmdeploy/pytorch/engine/model_agent.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 55b8253084..9ded7b045e 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -513,13 +513,13 @@ def _postprocess_forward_output_dllm(self, output: dict, inputs: ModelInputs, is """Post process for dllm.""" block_sparse_size = self.misc_config.block_sparse_size hidden_states = output['hidden_states'] - seq_length = inputs.seq_length if is_long_context: if not return_logits: hidden_states = hidden_states[:, -block_sparse_size:] else: hidden_states = hidden_states.to('cuda') else: + seq_length = inputs.seq_length is_decoding = seq_length.numel() * block_sparse_size == hidden_states.size(1) if not return_logits and not is_decoding: hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] @@ -530,8 +530,8 @@ def _postprocess_forward_output_default(self, output: dict, inputs: ModelInputs, return_logits: bool): """Post process forward output default.""" 
hidden_states = output['hidden_states'] - seq_length = inputs.seq_length if not is_long_context: + seq_length = inputs.seq_length is_decoding = seq_length.numel() != hidden_states.size(1) if not return_logits and not is_decoding: hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] @@ -628,6 +628,8 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): tmp_out['hidden_states'] = output_gather.get_output() return tmp_out + origin_inputs = inputs + # make long context inputs is_long_context = inputs.input_ids.numel() > max_prefill_token_num and not inputs.is_decoding max_seqlen = 0 @@ -653,7 +655,7 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): else: ret = await __long_context_single_forward(inputs, max_seqlen) - ret = self._postprocess_forward_output(ret, inputs, is_long_context, return_logits) + ret = self._postprocess_forward_output(ret, origin_inputs, is_long_context, return_logits) # compute dummy loop if dummy_loop > 0: From b65afc57a559b50f2c18e91b234ab906c1fdc541 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 2 Sep 2025 16:44:18 +0800 Subject: [PATCH 10/29] fix vlm --- lmdeploy/pytorch/messages.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 15ae47d466..a2b6bf2a17 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -685,7 +685,7 @@ def _update_multimodals(self, old_num_history_ids: int, multimodals: MultiModalI if multimodals is None: self._num_cross = 0 return - multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_all_ids) + multimodals = HistoryMultiModals.update_multimodals(multimodals, self.num_valid_ids) self.history_multimodals.add_inputs(multimodals) # for mllama @@ -699,7 +699,12 @@ def update_token_ids(self, mode: UpdateTokenMode = UpdateTokenMode.INPUTS, **kwargs): """Update token ids, old token ids will be added to history.""" - old_num_history_ids = self._num_history_ids + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(self._num_history_ids, multimodals) + self.arrive_time = time.perf_counter() token_ids = _to_ndarray(token_ids) @@ -718,12 +723,6 @@ def update_token_ids(self, self.history_cache.append(token_ids) - # update history image nums - self._update_embeddings(embeddings) - - # update multimodals - self._update_multimodals(old_num_history_ids, multimodals) - if model_meta is not None: self.model_meta = model_meta @@ -880,7 +879,12 @@ def update_token_ids(self, mode: UpdateTokenMode = UpdateTokenMode.INPUTS, **kwargs): """Update token ids, old token ids will be added to history.""" - old_num_history_ids = self._num_history_ids + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(self._num_history_ids, multimodals) + self.arrive_time = time.perf_counter() token_ids: np.ndarray = _to_ndarray(token_ids) @@ -895,12 +899,6 @@ def update_token_ids(self, else: self._update_token_ids_decode(token_ids, dllm_mask) - # update history image nums - self._update_embeddings(embeddings) - - # update multimodals - self._update_multimodals(old_num_history_ids, multimodals) - if model_meta is not None: self.model_meta = model_meta From da2f4031cc1e8a9b2a2698ee78fa3a2972fea275 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 2 Sep 2025 18:02:41 +0800 Subject: [PATCH 11/29] fix stopping --- lmdeploy/pytorch/messages.py | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index a2b6bf2a17..64bc1159a8 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -712,7 +712,7 @@ def update_token_ids(self, num_valid = len(token_ids) if mode == UpdateTokenMode.INPUTS: - self.input_pos = self.num_all_ids + self.input_pos = self.num_all_ids + len(token_ids) self._num_token_ids += num_valid self.num_new_tokens = 0 else: @@ -819,7 +819,7 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray) self.history_cache.append(token_ids) self.history_dllm_mask.append(dllm_mask) - self.input_pos = self._num_valid_ids + self.input_pos = self._num_valid_ids + len(token_ids) self._num_valid_ids = self.num_history_ids + num_tokens self._num_token_ids = len(token_ids) self.num_new_tokens = 0 From e6b5bdd4b11a08903dde2beb6879c14d03c87196 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 2 Sep 2025 18:49:44 +0800 Subject: [PATCH 12/29] move args into logitsprocessor --- lmdeploy/pytorch/engine/logits_process.py | 12 +++++------- lmdeploy/pytorch/engine/model_agent.py | 6 +----- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index aaed429f3d..e131e94e3b 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -399,12 +399,10 @@ class FusedLogitsProcessor: def __init__(self, sampling_inputs: SamplingInputs, - ignore_eos: torch.Tensor, tokenizer: Optional[Tokenizer] = None, sampling_vocab_size: Optional[int] = None, logprobs_mode: Optional[str] = None): self.sampling_inputs: SamplingInputs = sampling_inputs - self.ignore_eos = ignore_eos self.tokenizer = tokenizer self.sampling_vocab_size = sampling_vocab_size self.logprobs_mode = logprobs_mode @@ -415,12 +413,9 @@ async def _wait_stream_once(self): if not stream.query(): await asyncio.sleep(0) - async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.LongTensor, - scores: torch.FloatTensor) -> torch.FloatTensor: + async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: r""" Args: - all_ids (torch.LongTensor): All the token ids. - guided_input_ids (torch.LongTensor): Guided prompt ids. scores (torch.FloatTensor): Prediction scores of a language modeling head. 
These can be logits for each vocabulary when not using @@ -446,6 +441,8 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long logprobs = None sampling_inputs = self.sampling_inputs + all_ids = sampling_inputs.all_ids + guided_input_ids = sampling_inputs.guided_input_ids custom_logits_processors = self.sampling_inputs.logits_processors if any(custom_logits_processors): @@ -467,8 +464,9 @@ async def __call__(self, all_ids: torch.LongTensor, guided_input_ids: torch.Long stop_words = sampling_inputs.stop_words if stop_words is not None: + ignore_eos = sampling_inputs.num_ignore_eos > 0 stop_mask = sampling_inputs.stop_mask - stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False) + stop_mask = torch.where(ignore_eos[:, None], stop_mask, False) scores = _process_bad_words_(scores, stop_words, stop_mask) if guided_input_ids is not None: diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 9ded7b045e..93cf787bd5 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -678,16 +678,12 @@ async def async_sampling_logits(self, logits: torch.Tensor, sampling_inputs: Sam # record function does not support async function # so we can not decorate it on async_sampling_logits with record_function('sampling_logits'): - all_ids = sampling_inputs.all_ids - guided_input_ids = sampling_inputs.guided_input_ids - ignore_eos = sampling_inputs.num_ignore_eos > 0 logits_processor = FusedLogitsProcessor(sampling_inputs, - ignore_eos, self.tokenizer, sampling_vocab_size=self.sampling_vocab_size, logprobs_mode=self.misc_config.logprobs_mode) origin_logits = logits - logits, raw_logprobs = await logits_processor(all_ids, guided_input_ids, origin_logits) + logits, raw_logprobs = await logits_processor(origin_logits) next_token_ids = logits_processor.sampling(logits) logprobs = logits_processor.compute_logprobs(raw_logprobs, next_token_ids) if logprobs is not None: From 2b0e607230011fda7a42fff871e6206342439324 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 3 Sep 2025 12:11:36 +0800 Subject: [PATCH 13/29] rename --- lmdeploy/pytorch/engine/engine.py | 8 ++-- lmdeploy/pytorch/engine/model_agent.py | 66 ++++++++++++-------------- lmdeploy/pytorch/messages.py | 6 +-- 3 files changed, 38 insertions(+), 42 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 46ff82daa7..3efc672903 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -980,8 +980,8 @@ def __make_sampling_inputs(seqs: SeqList): block_sparse_size=block_sparse_size) return SamplingInputs.from_sampling_params(seqs, pad_token_id=pad_id) - def __get_input_pos(seqs: SeqList): - pos = [seq.input_pos for seq in seqs] + def __get_output_start_pos(seqs: SeqList): + pos = [seq.output_start_pos for seq in seqs] return torch.tensor(pos) scheduler = self.scheduler @@ -1012,7 +1012,7 @@ def __get_input_pos(seqs: SeqList): num_appendable_ids = __get_num_appendable_ids(running) return_logits = __need_logits(running) dllm_mask = __get_dllm_mask(running) - input_pos = __get_input_pos(running) + output_start_pos = __get_output_start_pos(running) sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num return dict( @@ -1027,7 +1027,7 @@ def __get_input_pos(seqs: SeqList): is_dummy=False, sync_long_context=sync_long_context, dllm_mask=dllm_mask, - input_pos=input_pos, + output_start_pos=output_start_pos, ) async def 
_await_forward_event(self, forward_event: asyncio.Event): diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 93cf787bd5..3684032dd7 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -263,7 +263,7 @@ def _batch_stopping_criteria_default(token_ids: torch.Tensor, stop_words: torch. def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor, stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor, - input_pos: torch.Tensor, inputs: ModelInputs): + output_start_pos: torch.Tensor, inputs: ModelInputs): num_tokens = token_ids.size(0) batch_size = num_appendable_ids.size(0) block_sparse_size = num_tokens // batch_size @@ -271,7 +271,7 @@ def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_ # blocks might contain stop words in prev-round chat # these stop words should be ignored kv_seqlens = inputs.history_lengths + inputs.seq_length - ignore_pos = (input_pos - (kv_seqlens - block_sparse_size)).clamp_min(0) + ignore_pos = (output_start_pos - (kv_seqlens - block_sparse_size)).clamp_min(0) ignore_range = torch.arange(0, block_sparse_size, dtype=ignore_pos.dtype, device=ignore_pos.device) ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten() token_ids = token_ids.clone() @@ -295,7 +295,7 @@ def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_ def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, - dllm_mask: torch.Tensor, input_pos: torch.Tensor, inputs: ModelInputs): + dllm_mask: torch.Tensor, output_start_pos: torch.Tensor, inputs: ModelInputs): """Batched stopping criteria.""" num_tokens = token_ids.size(0) batch_size = num_appendable_ids.size(0) @@ -317,7 +317,7 @@ def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Ten stopped, stop_pos, num_appendable_ids, - input_pos=input_pos, + output_start_pos=output_start_pos, inputs=inputs) return stopped, stop_pos, num_appendable_ids @@ -558,12 +558,12 @@ def _batch_stopping_criteria(self, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, dllm_mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, + output_start_pos: Optional[torch.Tensor] = None, inputs: Optional[ModelInputs] = None): """Batched stopping criteria.""" if self.model_config.model_paradigm == 'dllm': assert dllm_mask is not None - return _batch_stopping_criteria_dllm(token_ids, stop_words, num_appendable_ids, dllm_mask, input_pos, + return _batch_stopping_criteria_dllm(token_ids, stop_words, num_appendable_ids, dllm_mask, output_start_pos, inputs) return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) @@ -700,17 +700,18 @@ def _push_output(self, output: BatchedOutputs): event.record() self._out_que.put_nowait((output, event)) - def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None): + @contextmanager + def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None, enable: bool = True): + if not enable: + yield + return + if dist_ctx is None: dist_ctx = get_dist_manager().current_context() - if self.cache_config.role == EngineRole.Decode: - next_token_ids = next_token_ids.cpu() - tp_cpu_group = dist_ctx.tp_cpu_group - handle = dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group, async_op=True) - else: - tp_gpu_group = 
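# A minimal numeric sketch of the multiround stop-word handling above (helper
# name and values are made up for illustration): positions that sit before
# `output_start_pos` belong to the prompt of the current round, so a stop word
# among them must not terminate generation.
import torch


def ignore_prompt_stopwords(token_ids: torch.Tensor, stop_words: torch.Tensor,
                            output_start_pos: torch.Tensor, kv_seqlens: torch.Tensor,
                            block_length: int):
    """token_ids: flat [batch * block_length]; returns per-position stop-word
    hits with prompt positions masked out."""
    # how many leading positions of the current block are still prompt tokens
    ignore_pos = (output_start_pos - (kv_seqlens - block_length)).clamp_min(0)
    pos_in_block = torch.arange(block_length)
    ignore = (pos_in_block[None, :] < ignore_pos[:, None]).flatten()
    token_ids = token_ids.masked_fill(ignore, -1)  # -1 never matches a stop word
    return (token_ids[:, None] == stop_words).any(-1)


# e.g. block_length=4, kv_seqlens=[8], output_start_pos=[6]: the first two
# positions of the block are prompt tokens and cannot trigger a stop.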
dist_ctx.tp_gpu_group - handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) - return handle + tp_gpu_group = dist_ctx.tp_gpu_group + handle = dist.broadcast(next_token_ids, src=0, group=tp_gpu_group, async_op=True) + yield + handle.wait() def dllm_unmasking(self, inputs: ModelInputs, logits: torch.Tensor, next_token_ids: torch.LongTensor, dllm_mask: torch.Tensor): @@ -733,7 +734,7 @@ async def _async_step_background( is_dummy: bool = False, sync_long_context: bool = False, dllm_mask: torch.Tensor = None, - input_pos: torch.Tensor = None, + output_start_pos: torch.Tensor = None, ): """Asyc forward task.""" dist_ctx = get_dist_manager().current_context() @@ -879,23 +880,20 @@ async def __prepare_dp(): # sampling next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs) - # broadcast next token for TP > 1 - if need_broadcast_next: + with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - handle = self._broadcast_next_token(next_token_ids, dist_ctx) - - # unmasking - dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) - # stopping criteria - stopped, stop_pos, num_appendable_ids = self._batch_stopping_criteria(next_token_ids, - sampling_inputs.stop_words, - num_appendable_ids, - dllm_mask=dllm_mask, - input_pos=input_pos, - inputs=inputs) - if need_broadcast_next: - handle.wait() + # unmasking + dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) + + # stopping criteria + stopped, stop_pos, num_appendable_ids = self._batch_stopping_criteria( + next_token_ids, + sampling_inputs.stop_words, + num_appendable_ids, + dllm_mask=dllm_mask, + output_start_pos=output_start_pos, + inputs=inputs) else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile. 
@@ -904,10 +902,8 @@ async def __prepare_dp(): logprobs = None # broadcast next token for TP > 1 - if need_broadcast_next: + with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - handle = self._broadcast_next_token(next_token_ids, dist_ctx) - handle.wait() # unmasking dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) @@ -955,7 +951,7 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None): async def _async_loop_inputs_preprocess(self): """Async loop inputs preprocess.""" non_blocking = True - keys = ['inputs', 'sampling_inputs', 'num_appendable_ids', 'dllm_mask', 'input_pos'] + keys = ['inputs', 'sampling_inputs', 'num_appendable_ids', 'dllm_mask', 'output_start_pos'] while True: forward_inputs = await self._pre_in_que.get() diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 64bc1159a8..118d42a704 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -498,7 +498,7 @@ class SchedulerSequence: logical_blocks: LogicalTokenBlocks = field(default_factory=LogicalTokenBlocks) adapter_name: str = None arrive_time: float = 0.0 - input_pos: int = 0 + output_start_pos: int = 0 meta: Any = None _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 @@ -712,7 +712,7 @@ def update_token_ids(self, num_valid = len(token_ids) if mode == UpdateTokenMode.INPUTS: - self.input_pos = self.num_all_ids + len(token_ids) + self.output_start_pos = self.num_all_ids + len(token_ids) self._num_token_ids += num_valid self.num_new_tokens = 0 else: @@ -819,7 +819,7 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray) self.history_cache.append(token_ids) self.history_dllm_mask.append(dllm_mask) - self.input_pos = self._num_valid_ids + len(token_ids) + self.output_start_pos = self._num_valid_ids + len(token_ids) self._num_valid_ids = self.num_history_ids + num_tokens self._num_token_ids = len(token_ids) self.num_new_tokens = 0 From a660a4324c19e249113d60f67d7a89af23040c52 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 3 Sep 2025 12:44:46 +0800 Subject: [PATCH 14/29] fix pd --- lmdeploy/pytorch/paging/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index ef542974b6..8144ac52a1 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -291,7 +291,7 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" if is_prefill: - output = self._schedule_prefill(prealloc_size) + output = self._schedule_prefill(0) else: output = self._schedule_decoding(prealloc_size) running, swap_in_map, swap_out_map, copy_map = output From b23d96258edfac00f0b08a58f295d833896b8d40 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 3 Sep 2025 17:25:18 +0800 Subject: [PATCH 15/29] rename --- .../pytorch/backends/cuda/graph_runner.py | 4 +- lmdeploy/pytorch/config.py | 8 ++-- lmdeploy/pytorch/engine/engine.py | 26 +++++------ lmdeploy/pytorch/engine/executor/base.py | 2 +- lmdeploy/pytorch/engine/logits_process.py | 19 ++++---- lmdeploy/pytorch/engine/model_agent.py | 44 +++++++++---------- lmdeploy/pytorch/engine/unmasking.py | 32 +++++++------- lmdeploy/pytorch/messages.py | 38 +++++++++------- 
lmdeploy/pytorch/model_inputs.py | 8 ++-- lmdeploy/pytorch/models/sdar.py | 6 +-- 10 files changed, 95 insertions(+), 92 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index a866f584b5..dcbeb28d38 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -195,8 +195,8 @@ def _get_max_tokens(self, graph_key: tuple): if model_paradigm == 'dllm': step_mgr = get_step_ctx_manager() build_ctx = step_mgr.build_ctx - block_sparse_size = build_ctx.block_sparse_size - return max_batches * block_sparse_size + dllm_block_length = build_ctx.dllm_config.dllm_block_length + return max_batches * dllm_block_length return max_batches def __call__(self, **kwargs): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 1326c6e4b2..8c36ec2782 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -289,7 +289,7 @@ def from_hf_config(cls, @dataclass class DLLMConfig: - block_sparse_size: int = 1 + dllm_block_length: int = 1 unmasking_strategy: str = 'low_confidence_dynamic' denoising_steps: int = None confidence_threshold: float = 0.85 @@ -304,7 +304,7 @@ class MiscConfig: hf_overrides: Dict[str, Any] = None disable_vision_encoder: bool = False logprobs_mode: str = None - block_sparse_size: int = 1 + dllm_block_length: int = 1 dllm_config: DLLMConfig = None @classmethod @@ -313,7 +313,7 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig): denoising_steps = engine_config.dllm_denoising_steps if denoising_steps is None: denoising_steps = engine_config.dllm_block_length // 2 - dllm_config = DLLMConfig(block_sparse_size=engine_config.dllm_block_length, + dllm_config = DLLMConfig(dllm_block_length=engine_config.dllm_block_length, unmasking_strategy=engine_config.dllm_unmasking_strategy, denoising_steps=denoising_steps, confidence_threshold=engine_config.dllm_confidence_threshold) @@ -323,7 +323,7 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig): model_format=engine_config.model_format, hf_overrides=engine_config.hf_overrides, disable_vision_encoder=engine_config.disable_vision_encoder, - block_sparse_size=engine_config.dllm_block_length, + dllm_block_length=engine_config.dllm_block_length, logprobs_mode=engine_config.logprobs_mode, dllm_config=dllm_config) return misc_config diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 3efc672903..754830f0b3 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -144,7 +144,7 @@ def _build_seq_meta(cache_config: CacheConfig, model_config: ModelConfig, engine seq_meta = SequenceMeta(cache_config.block_size, model_paradigm=model_config.model_paradigm, - block_sparse_size=engine_config.dllm_block_length, + dllm_block_length=engine_config.dllm_block_length, dllm_mask_token=model_config.dllm_mask_token) return seq_meta @@ -396,6 +396,7 @@ def __init__(self, self.cache_config = cache_config self.backend_config = backend_config self.dist_config = dist_config + self.misc_config = self.executor.misc_config self.max_session_len = self._get_max_session_len() self.engine_config.num_cpu_blocks = self.cache_config.num_cpu_blocks self.engine_config.num_gpu_blocks = self.cache_config.num_gpu_blocks @@ -816,9 +817,9 @@ def _update_running_default(self, running: SeqList, next_token_ids: torch.Tensor def _update_running_dllm(self, running: SeqList, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, stopped: 
List[bool], model_metas: List[Any], is_decoding: bool, stop_pos: torch.Tensor): - block_sparse_size = self.seq_meta.block_sparse_size - next_token_ids = next_token_ids.view(-1, block_sparse_size).numpy() - dllm_mask = dllm_mask.view(-1, block_sparse_size).numpy() + batch_size = len(running) + next_token_ids = next_token_ids.view(batch_size, -1).numpy() + dllm_mask = dllm_mask.view(batch_size, -1).numpy() stop_pos = stop_pos.tolist() update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL for idx, token in enumerate(next_token_ids): @@ -928,9 +929,9 @@ def __get_prealloc(prefill_interval: int): model_paradigm = self.model_config.model_paradigm if model_paradigm == 'dllm': block_size = self.cache_config.block_size - block_sparse_size = self.seq_meta.block_sparse_size - num_blocks = min(prefill_interval // 2, block_size // block_sparse_size) - return num_blocks * block_sparse_size + dllm_block_length = self.misc_config.dllm_config.dllm_block_length + num_blocks = min(prefill_interval // 2, block_size // dllm_block_length) + return num_blocks * dllm_block_length else: return prefill_interval @@ -940,8 +941,8 @@ def __get_num_loops(is_prefill: bool, prefill_interval: int): return 1 if model_paradigm == 'dllm': block_size = self.cache_config.block_size - block_sparse_size = self.seq_meta.block_sparse_size - max_num_loops = block_size // block_sparse_size * 2 + dllm_block_length = self.misc_config.dllm_config.dllm_block_length + max_num_loops = block_size // dllm_block_length * 2 num_loops = min(prefill_interval, max_num_loops) return num_loops else: @@ -952,8 +953,8 @@ def __get_num_appendable_ids(seqs: SeqList): num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] num_appendable = torch.tensor(num_appendable) if self.model_config.model_paradigm == 'dllm': - block_sparse_size = self.seq_meta.block_sparse_size - remain = [seq.num_valid_ids % block_sparse_size for seq in seqs] + dllm_block_length = self.misc_config.dllm_config.dllm_block_length + remain = [seq.num_valid_ids % dllm_block_length for seq in seqs] num_appendable += torch.tensor(remain) return num_appendable @@ -974,10 +975,9 @@ def __make_sampling_inputs(seqs: SeqList): pad_id = 0 if pad_id is None else pad_id if self.model_config.model_paradigm == 'dllm': from .logits_process import SamplingInputsDLLM - block_sparse_size = self.seq_meta.block_sparse_size return SamplingInputsDLLM.from_sampling_params(seqs, pad_token_id=pad_id, - block_sparse_size=block_sparse_size) + dllm_config=self.misc_config.dllm_config) return SamplingInputs.from_sampling_params(seqs, pad_token_id=pad_id) def __get_output_start_pos(seqs: SeqList): diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 647504c01b..b1fc8f71bf 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -36,7 +36,7 @@ def __init__(self, self.cache_config = cache_config self.backend_config = backend_config self.dist_config = dist_config - self.misc_config = misc_config, + self.misc_config = misc_config self.tokenizer = tokenizer self.dp = dist_config.dp self.tp = dist_config.tp diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index e131e94e3b..a1ace91884 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -7,6 +7,7 @@ import torch from lmdeploy.messages import LogitsProcessor +from lmdeploy.pytorch.config import DLLMConfig from 
lmdeploy.tokenizer import Tokenizer from ..messages import SchedulerSequence @@ -332,11 +333,11 @@ def step(self, next_token_ids: torch.Tensor, **kwargs): class SamplingInputsDLLM(SamplingInputs): @classmethod - def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0, block_sparse_size: int = 1): + def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0, dllm_config: DLLMConfig = None): """From samplingg params.""" out = super().from_sampling_params(seqs, pad_token_id) - if block_sparse_size == 1: - return out + assert dllm_config is not None + dllm_block_length = dllm_config.dllm_block_length # repeat tensor update_attr_names = [ @@ -359,17 +360,17 @@ def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = attr = getattr(out, name) if attr is None: continue - repeats = (block_sparse_size, ) + (1, ) * (attr.dim()) + repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) attr = attr[None].repeat(*repeats).flatten(0, 1) setattr(out, name, attr) if len(out.response_formats) > 0: new_resp_formats = [] for resp in out.response_formats: - new_resp_formats += [resp] * block_sparse_size - out.response_formats = new_resp_formats + new_resp_formats += [resp] * dllm_block_length + out.response_formats = tuple(new_resp_formats) - out.batch_size *= block_sparse_size + out.batch_size *= dllm_block_length return out @@ -377,11 +378,11 @@ def step(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, **kwargs): """To next step.""" from lmdeploy.pytorch import consts batch_size = self.batch_size - block_sparse_size = next_token_ids.numel() // batch_size + dllm_block_size = next_token_ids.numel() // batch_size DLLM_UNMASKED = consts.DLLM_UNMASKED is_unmasked = (dllm_mask == DLLM_UNMASKED).view(batch_size, -1).all(dim=1, keepdim=True) num_ignore_eos = self.num_ignore_eos.view(batch_size, -1) - num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - block_sparse_size, num_ignore_eos) + num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos) self.num_ignore_eos = num_ignore_eos.flatten() diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 3684032dd7..d177ea4f10 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -266,20 +266,20 @@ def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_ output_start_pos: torch.Tensor, inputs: ModelInputs): num_tokens = token_ids.size(0) batch_size = num_appendable_ids.size(0) - block_sparse_size = num_tokens // batch_size + block_size = num_tokens // batch_size # blocks might contain stop words in prev-round chat # these stop words should be ignored kv_seqlens = inputs.history_lengths + inputs.seq_length - ignore_pos = (output_start_pos - (kv_seqlens - block_sparse_size)).clamp_min(0) - ignore_range = torch.arange(0, block_sparse_size, dtype=ignore_pos.dtype, device=ignore_pos.device) + ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0) + ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device) ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten() token_ids = token_ids.clone() token_ids[ignore_mask] = -1 # find stop words sw_stopped = (token_ids[:, None] == stop_words).any(1) - sw_stopped = sw_stopped.view(batch_size, block_sparse_size) + sw_stopped = sw_stopped.view(batch_size, block_size) sw_stop_pos = sw_stopped.int().argmax(1) stop_pos = torch.where(stopped, 
stop_pos, sw_stop_pos) @@ -299,15 +299,15 @@ def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Ten """Batched stopping criteria.""" num_tokens = token_ids.size(0) batch_size = num_appendable_ids.size(0) - block_sparse_size = num_tokens // batch_size + block_size = num_tokens // batch_size - dllm_mask = dllm_mask.view(batch_size, block_sparse_size) + dllm_mask = dllm_mask.view(batch_size, block_size) is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1) # check stop by num_new_tokens - num_appendable_ids -= is_unmasked * block_sparse_size + num_appendable_ids -= is_unmasked * block_size stopped = num_appendable_ids <= 0 - stop_pos = block_sparse_size - 1 + num_appendable_ids + stop_pos = block_size - 1 + num_appendable_ids # check stop words if stop_words is not None: @@ -436,9 +436,6 @@ def __init__(self, def _build_unmasking_processor(self): """Build unmasking processor.""" - # block_sparse_size = self.misc_config.block_sparse_size - # strategy = 'low_confidence_dynamic' if self.model_config.model_paradigm == 'dllm' else None - # denoising_steps = max(1, block_sparse_size // 2) dllm_config = self.misc_config.dllm_config unmasking_processor = UnmaskingProcessor(dllm_config) return unmasking_processor @@ -494,11 +491,11 @@ def warmup(self): def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): """Slice outputs.""" if self.model_config.model_paradigm == 'dllm': - block_sparse_size = self.misc_config.block_sparse_size - if len(seq_length) * block_sparse_size == inputs.size(0): + block_length = self.misc_config.dllm_config.dllm_block_length + if len(seq_length) * block_length == inputs.size(0): return inputs last_idx = seq_length.cumsum(0) - block_range = torch.arange(-block_sparse_size, 0, device=last_idx.device) + block_range = torch.arange(-block_length, 0, device=last_idx.device) index = (last_idx[:, None] + block_range[None, :]).flatten() inputs = inputs[index] return inputs @@ -511,16 +508,16 @@ def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): def _postprocess_forward_output_dllm(self, output: dict, inputs: ModelInputs, is_long_context: bool, return_logits: bool): """Post process for dllm.""" - block_sparse_size = self.misc_config.block_sparse_size + block_length = self.misc_config.dllm_config.dllm_block_length hidden_states = output['hidden_states'] if is_long_context: if not return_logits: - hidden_states = hidden_states[:, -block_sparse_size:] + hidden_states = hidden_states[:, -block_length:] else: hidden_states = hidden_states.to('cuda') else: seq_length = inputs.seq_length - is_decoding = seq_length.numel() * block_sparse_size == hidden_states.size(1) + is_decoding = seq_length.numel() * block_length == hidden_states.size(1) if not return_logits and not is_decoding: hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] output['hidden_states'] = hidden_states @@ -746,11 +743,12 @@ def __update_dllm(next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens return next_token_ids, dllm_mask, seqlens dllm_mask_token = self.model_config.dllm_mask_token - block_sparse_size = self.build_model_ctx.block_sparse_size + dllm_config = self.misc_config.dllm_config + dllm_block_length = dllm_config.dllm_block_length - # reshape to (batch, block_sparse_size) - next_token_ids = next_token_ids.view(-1, block_sparse_size).clone() - dllm_mask = dllm_mask.view(-1, block_sparse_size).clone() + # reshape to (batch, dllm_block_length) + next_token_ids = next_token_ids.view(-1, dllm_block_length).clone() + 
dllm_mask = dllm_mask.view(-1, dllm_block_length).clone() # flags is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1) @@ -1088,8 +1086,8 @@ def _build_model(self): update_custom_module_map(custom_module_map) logger.debug(msg_with_rank(rank, 'build model.')) build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder, - block_sparse_size=self.misc_config.block_sparse_size, - model_paradigm=self.model_config.model_paradigm) + model_paradigm=self.model_config.model_paradigm, + dllm_config=self.misc_config.dllm_config) patched_model = build_patched_model(self.model_config, device=device, model_format=self.misc_config.model_format, diff --git a/lmdeploy/pytorch/engine/unmasking.py b/lmdeploy/pytorch/engine/unmasking.py index 0433b5a4b1..6b77102e8f 100644 --- a/lmdeploy/pytorch/engine/unmasking.py +++ b/lmdeploy/pytorch/engine/unmasking.py @@ -23,24 +23,24 @@ def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor): def _get_denoise_num(self): """Get denoise num.""" - block_sparse_size = self.dllm_config.block_sparse_size + block_size = self.dllm_config.dllm_block_length denoising_steps = self.dllm_config.denoising_steps if denoising_steps is None: - denoising_steps = block_sparse_size - num = block_sparse_size // self.dllm_config.denoising_steps - num = max(1, min(num, block_sparse_size)) + denoising_steps = block_size + num = block_size // self.dllm_config.denoising_steps + num = max(1, min(num, block_size)) return num def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """static.""" - block_sparse_size = self.dllm_config.block_sparse_size + block_size = self.dllm_config.dllm_block_length topk = self._get_denoise_num() scores = self._get_scores(logits, token_ids) is_masked = dllm_mask == DLLM_MASKED scores = torch.where(is_masked, scores, scores.new_zeros((1, ))) - scores = scores.view(-1, block_sparse_size) - dllm_mask = dllm_mask.view(-1, block_sparse_size) + scores = scores.view(-1, block_size) + dllm_mask = dllm_mask.view(-1, block_size) _, indices = scores.topk(topk, dim=-1) dllm_unmasked = dllm_mask.scatter(-1, indices, DLLM_UNMASKED) @@ -50,14 +50,14 @@ def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, d def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """dynamic.""" - block_sparse_size = self.dllm_config.block_sparse_size + block_size = self.dllm_config.dllm_block_length threshold = self.dllm_config.confidence_threshold scores = self._get_scores(logits, token_ids) is_masked = dllm_mask == DLLM_MASKED scores = torch.where(is_masked, scores, scores.new_zeros((1, ))) - scores = scores.view(-1, block_sparse_size) - dllm_mask = dllm_mask.view(-1, block_sparse_size) + scores = scores.view(-1, block_size) + dllm_mask = dllm_mask.view(-1, block_size) _, indices = scores.topk(1, dim=-1) scores = scores.scatter(-1, indices, threshold) @@ -68,16 +68,16 @@ def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, def sequential(self, dllm_mask: torch.Tensor): """sequential.""" - block_sparse_size = self.dllm_config.block_sparse_size + block_size = self.dllm_config.dllm_block_length denoise_num = self._get_denoise_num() - dllm_mask = dllm_mask.view(-1, block_sparse_size) + dllm_mask = dllm_mask.view(-1, block_size) is_masked = dllm_mask == DLLM_MASKED # get indices indices = is_masked.int().argmax(dim=1) ranges = torch.arange(0, denoise_num, device=indices.device, dtype=indices.dtype) 
indices = indices[:, None] + ranges[None, :] - indices = indices % block_sparse_size + indices = indices % block_size dllm_unmasked = dllm_mask.clone() dllm_unmasked = dllm_unmasked.scatter(-1, indices, DLLM_UNMASKED) @@ -92,9 +92,9 @@ def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: tor if strategy is None: return dllm_mask - # reshape to [num_blocks, block_sparse_size] - block_sparse_size = self.dllm_config.block_sparse_size - dllm_mask = dllm_mask.unflatten(0, (-1, block_sparse_size)) + # reshape to [num_blocks, block_size] + block_size = self.dllm_config.dllm_block_length + dllm_mask = dllm_mask.unflatten(0, (-1, block_size)) is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1) first_mask = dllm_mask[:, 0] diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 118d42a704..fcbb3ff200 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -162,7 +162,7 @@ class SequenceMeta: """Meta data shared by all sequence.""" block_size: int model_paradigm: str = 'llm' - block_sparse_size: int = 1 + dllm_block_length: int = 1 dllm_mask_token: int = 151669 @@ -778,16 +778,20 @@ def generated_ids(self) -> np.ndarray: def all_dllm_mask(self): return self.history_dllm_mask._token_ids[:self.num_all_ids] + @property + def dllm_block_length(self): + return self._seq_meta.dllm_block_length + def set_stop_pos(self, pos: int): - block_sparse_size = self._seq_meta.block_sparse_size - val = block_sparse_size - pos - 1 + dllm_block_length = self.dllm_block_length + val = dllm_block_length - pos - 1 self._num_valid_ids -= val self.num_new_tokens -= val def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray): """Append tokens.""" num_tokens = len(token_ids) - block_sparse_size = self._seq_meta.block_sparse_size + dllm_block_length = self.dllm_block_length dllm_mask_token = self._seq_meta.dllm_mask_token new_token_ids = [token_ids] new_dllm_mask = [dllm_mask] @@ -804,8 +808,8 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray) self.history_dllm_mask.resize(self.num_history_ids) num_tokens += num_remain_valid - # pad to align with block_sparse_size - num_pad = (-num_tokens) % block_sparse_size + # pad to align with dllm_block_length + num_pad = (-num_tokens) % dllm_block_length if num_pad > 0: pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, )) pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, )) @@ -815,7 +819,7 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray) token_ids = np.concatenate(new_token_ids) dllm_mask = np.concatenate(new_dllm_mask) - assert len(token_ids) % block_sparse_size == 0 + assert len(token_ids) % dllm_block_length == 0 self.history_cache.append(token_ids) self.history_dllm_mask.append(dllm_mask) @@ -827,9 +831,9 @@ def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray) def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray): """Update token ids for decode.""" num_tokens = len(token_ids) - block_sparse_size = self._seq_meta.block_sparse_size + dllm_block_length = self.dllm_block_length dllm_mask_token = self._seq_meta.dllm_mask_token - assert num_tokens % block_sparse_size == 0 + assert num_tokens % dllm_block_length == 0 num_history_ids = self.num_history_ids token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token @@ -837,32 +841,32 @@ def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray) 
self.history_dllm_mask[num_history_ids:] = dllm_mask # check if all blocks are cached - last_mask = dllm_mask[-block_sparse_size:] + last_mask = dllm_mask[-dllm_block_length:] is_unmasked = np.all(last_mask == DLLM_UNMASKED) is_cached = np.all(last_mask == DLLM_CACHED) if is_unmasked: - num_new = block_sparse_size - self._num_valid_ids % block_sparse_size + num_new = dllm_block_length - self._num_valid_ids % dllm_block_length self._num_valid_ids += num_new self.num_new_tokens += num_new if is_cached: # add new block - new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(block_sparse_size, )) - new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(block_sparse_size, )) + new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, )) + new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, )) self.history_cache.append(new_token_ids) self.history_dllm_mask.append(new_dllm_mask) self._num_history_ids += self._num_token_ids - self._num_token_ids = block_sparse_size + self._num_token_ids = dllm_block_length def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray): """Update token ids for prefill.""" - block_sparse_size = self._seq_meta.block_sparse_size + dllm_block_length = self.dllm_block_length num_history_ids = self.num_history_ids # fill input cache - if self.num_token_ids > block_sparse_size: - end = self.num_token_ids - block_sparse_size + if self.num_token_ids > dllm_block_length: + end = self.num_token_ids - dllm_block_length self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED self._num_history_ids += end self._num_token_ids -= end diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index f3d18112c5..2d43b33a4e 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -9,7 +9,7 @@ # from torch import distributed as dist import lmdeploy.pytorch.distributed as dist from lmdeploy.pytorch.backends import get_backend -from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.config import DLLMConfig, ModelConfig from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor @@ -293,8 +293,8 @@ def make_dummy(cls, """Make dummy inputs.""" model_paradigm = build_ctx.model_paradigm if model_paradigm == 'dllm': - block_sparse_size = build_ctx.block_sparse_size - max_q_seqlen = block_sparse_size + block_size = build_ctx.dllm_config.dllm_block_length + max_q_seqlen = block_size else: max_q_seqlen = 1 num_tokens = batch_size * max_q_seqlen @@ -475,8 +475,8 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): class BuildModelContext: """Context for building model.""" disable_vision_encoder: bool = False - block_sparse_size: int = 1 model_paradigm: str = 'llm' + dllm_config: DLLMConfig = None class StepContextManager: diff --git a/lmdeploy/pytorch/models/sdar.py b/lmdeploy/pytorch/models/sdar.py index ab1f4ced29..9f69c525e0 100644 --- a/lmdeploy/pytorch/models/sdar.py +++ b/lmdeploy/pytorch/models/sdar.py @@ -40,7 +40,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: # rotary embedding self.apply_rotary_pos_emb = ApplyRotaryEmb() - block_sparse_size = config.block_sparse_size + dllm_block_length = config.dllm_block_length # attention self.attn_fwd = Attention( @@ -49,7 +49,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: num_kv_heads=num_key_value_heads, v_head_size=head_dim, sliding_window=config.sliding_window, - block_sparse_size=block_sparse_size, + 
block_sparse_size=dllm_block_length, ) # o_proj @@ -301,7 +301,7 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr - config.block_sparse_size = ctx_mgr.build_ctx.block_sparse_size + config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.dllm_block_length # build model self.model = SDARModel(config, dtype=dtype, device=device) # build lm_head From 34e41aa4ea1fbd3f86b1627b01b4fb4bff3ae933 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 5 Sep 2025 15:56:44 +0800 Subject: [PATCH 16/29] strategy + abstruct factory --- .../pytorch/backends/cuda/graph_runner.py | 14 +- lmdeploy/pytorch/config.py | 4 +- lmdeploy/pytorch/engine/engine.py | 84 +---- lmdeploy/pytorch/engine/logits_process.py | 238 -------------- lmdeploy/pytorch/engine/model_agent.py | 300 ++++-------------- lmdeploy/pytorch/messages.py | 6 +- lmdeploy/pytorch/model_inputs.py | 50 +-- lmdeploy/pytorch/strategies/__init__.py | 16 + lmdeploy/pytorch/strategies/ar/__init__.py | 49 +++ lmdeploy/pytorch/strategies/ar/cudagraph.py | 9 + lmdeploy/pytorch/strategies/ar/engine.py | 20 ++ lmdeploy/pytorch/strategies/ar/model_agent.py | 108 +++++++ .../pytorch/strategies/ar/model_inputs.py | 21 ++ lmdeploy/pytorch/strategies/ar/sampling.py | 191 +++++++++++ lmdeploy/pytorch/strategies/base/__init__.py | 42 +++ lmdeploy/pytorch/strategies/base/cudagraph.py | 10 + lmdeploy/pytorch/strategies/base/engine.py | 16 + .../pytorch/strategies/base/model_agent.py | 132 ++++++++ .../pytorch/strategies/base/model_inputs.py | 54 ++++ lmdeploy/pytorch/strategies/base/sampling.py | 17 + lmdeploy/pytorch/strategies/dllm/__init__.py | 53 ++++ lmdeploy/pytorch/strategies/dllm/cudagraph.py | 13 + lmdeploy/pytorch/strategies/dllm/engine.py | 35 ++ .../pytorch/strategies/dllm/model_agent.py | 218 +++++++++++++ .../pytorch/strategies/dllm/model_inputs.py | 24 ++ lmdeploy/pytorch/strategies/dllm/sampling.py | 57 ++++ .../{engine => strategies/dllm}/unmasking.py | 0 27 files changed, 1179 insertions(+), 602 deletions(-) create mode 100644 lmdeploy/pytorch/strategies/__init__.py create mode 100644 lmdeploy/pytorch/strategies/ar/__init__.py create mode 100644 lmdeploy/pytorch/strategies/ar/cudagraph.py create mode 100644 lmdeploy/pytorch/strategies/ar/engine.py create mode 100644 lmdeploy/pytorch/strategies/ar/model_agent.py create mode 100644 lmdeploy/pytorch/strategies/ar/model_inputs.py create mode 100644 lmdeploy/pytorch/strategies/ar/sampling.py create mode 100644 lmdeploy/pytorch/strategies/base/__init__.py create mode 100644 lmdeploy/pytorch/strategies/base/cudagraph.py create mode 100644 lmdeploy/pytorch/strategies/base/engine.py create mode 100644 lmdeploy/pytorch/strategies/base/model_agent.py create mode 100644 lmdeploy/pytorch/strategies/base/model_inputs.py create mode 100644 lmdeploy/pytorch/strategies/base/sampling.py create mode 100644 lmdeploy/pytorch/strategies/dllm/__init__.py create mode 100644 lmdeploy/pytorch/strategies/dllm/cudagraph.py create mode 100644 lmdeploy/pytorch/strategies/dllm/engine.py create mode 100644 lmdeploy/pytorch/strategies/dllm/model_agent.py create mode 100644 lmdeploy/pytorch/strategies/dllm/model_inputs.py create mode 100644 lmdeploy/pytorch/strategies/dllm/sampling.py rename lmdeploy/pytorch/{engine => strategies/dllm}/unmasking.py (100%) diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py index dcbeb28d38..deb6c66bfd 100644 --- a/lmdeploy/pytorch/backends/cuda/graph_runner.py +++ 
b/lmdeploy/pytorch/backends/cuda/graph_runner.py @@ -9,6 +9,7 @@ from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta +from lmdeploy.pytorch.strategies.base import StrategyFactoryBase from lmdeploy.utils import get_logger from ..graph_runner import GraphRunner @@ -147,6 +148,11 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf self._runner_map: Dict[Any, CUDASingleGraphRunner] = dict() self.has_try_compile_model: bool = False + # strategy factory + build_ctx = model.ctx_mgr.build_ctx + strategy_factory: StrategyFactoryBase = build_ctx.strategy_factory + self.cudagraph_strategy = strategy_factory.build_cudagraph_strategy() + def check_enable_graph(self): """Check enable graph.""" if self.backend_config.eager_mode: @@ -191,13 +197,7 @@ def _get_max_tokens(self, graph_key: tuple): max_batches = graph_key[0] is_decoding = graph_key[1] assert is_decoding - model_paradigm = self.model_config.model_paradigm - if model_paradigm == 'dllm': - step_mgr = get_step_ctx_manager() - build_ctx = step_mgr.build_ctx - dllm_block_length = build_ctx.dllm_config.dllm_block_length - return max_batches * dllm_block_length - return max_batches + return self.cudagraph_strategy.get_max_tokens(max_batches) def __call__(self, **kwargs): """call.""" diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 8c36ec2782..c75358d9cc 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -200,7 +200,7 @@ class ModelConfig: cogvlm_style: bool = False custom_module_map: Dict[str, setattr] = None use_flash_mla: bool = False - model_paradigm: str = 'llm' + model_paradigm: str = 'ar' dllm_mask_token: int = 0 def get_head_size(self): @@ -304,7 +304,6 @@ class MiscConfig: hf_overrides: Dict[str, Any] = None disable_vision_encoder: bool = False logprobs_mode: str = None - dllm_block_length: int = 1 dllm_config: DLLMConfig = None @classmethod @@ -323,7 +322,6 @@ def from_engine_config(cls, engine_config: PytorchEngineConfig): model_format=engine_config.model_format, hf_overrides=engine_config.hf_overrides, disable_vision_encoder=engine_config.disable_vision_encoder, - dllm_block_length=engine_config.dllm_block_length, logprobs_mode=engine_config.logprobs_mode, dllm_config=dllm_config) return misc_config diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 754830f0b3..4b5f1ad733 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -23,10 +23,10 @@ from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler +from ..strategies import build_strategy_factory from .base import EngineBase from .engine_checker import EngineChecker from .executor import build_executor -from .logits_process import SamplingInputs from .model_agent import BatchedOutputs from .request import Request, RequestManager, RequestType, Response @@ -383,6 +383,13 @@ def __init__(self, dtype=engine_config.dtype) self.executor.init() + # strategies + self.strategy_factory = build_strategy_factory(self.model_config, self.executor.misc_config) + self.sampling_strategy = self.strategy_factory.build_sampling_strategy() + self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy() + self.engine_strategy = 
self.strategy_factory.build_engine_strategy(cache_config=cache_config, + scheduler_config=scheduler_config) + self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) @@ -847,7 +854,7 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d if model_metas is None: model_metas = [None] * len(running) if self.model_config.model_paradigm == 'dllm': - dllm_mask = batched_outputs.dllm_mask + dllm_mask = batched_outputs.extra_outputs.dllm_mask stop_pos = batched_outputs.stop_pos return self._update_running_dllm(running, next_token_ids, dllm_mask, stopped, model_metas, is_decoding, stop_pos) @@ -923,71 +930,15 @@ def _make_infer_outputs( def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): """Make forward inputs.""" - prefill_interval = self.scheduler_config.prefill_interval - - def __get_prealloc(prefill_interval: int): - model_paradigm = self.model_config.model_paradigm - if model_paradigm == 'dllm': - block_size = self.cache_config.block_size - dllm_block_length = self.misc_config.dllm_config.dllm_block_length - num_blocks = min(prefill_interval // 2, block_size // dllm_block_length) - return num_blocks * dllm_block_length - else: - return prefill_interval - - def __get_num_loops(is_prefill: bool, prefill_interval: int): - model_paradigm = self.model_config.model_paradigm - if is_prefill: - return 1 - if model_paradigm == 'dllm': - block_size = self.cache_config.block_size - dllm_block_length = self.misc_config.dllm_config.dllm_block_length - max_num_loops = block_size // dllm_block_length * 2 - num_loops = min(prefill_interval, max_num_loops) - return num_loops - else: - return prefill_interval - - def __get_num_appendable_ids(seqs: SeqList): - """Get num appendable ids.""" - num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] - num_appendable = torch.tensor(num_appendable) - if self.model_config.model_paradigm == 'dllm': - dllm_block_length = self.misc_config.dllm_config.dllm_block_length - remain = [seq.num_valid_ids % dllm_block_length for seq in seqs] - num_appendable += torch.tensor(remain) - return num_appendable - - def __get_dllm_mask(seqs: SeqList): - """Get dllm mask.""" - if self.model_config.model_paradigm != 'dllm': - return None - dllm_masks = [seq.dllm_mask for seq in seqs] - dllm_masks = torch.as_tensor(np.concatenate(dllm_masks)) - return dllm_masks def __need_logits(seqs: SeqList): """Need logits.""" return any(seq.return_logits for seq in seqs) - def __make_sampling_inputs(seqs: SeqList): - pad_id = self.model_config.bos_token_id - pad_id = 0 if pad_id is None else pad_id - if self.model_config.model_paradigm == 'dllm': - from .logits_process import SamplingInputsDLLM - return SamplingInputsDLLM.from_sampling_params(seqs, - pad_token_id=pad_id, - dllm_config=self.misc_config.dllm_config) - return SamplingInputs.from_sampling_params(seqs, pad_token_id=pad_id) - - def __get_output_start_pos(seqs: SeqList): - pos = [seq.output_start_pos for seq in seqs] - return torch.tensor(pos) - scheduler = self.scheduler logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') - prealloc_size = __get_prealloc(prefill_interval) + prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) if enable_empty and len(scheduler_output.running) == 0: @@ -996,9 +947,10 @@ def 
__get_output_start_pos(seqs: SeqList): # schedule decoding if no valid prefill reqs. if prefill and len(scheduler_output.running) == 0 and self.engine_config.role != EngineRole.Prefill: prefill = False + prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) - num_loops = __get_num_loops(prefill, prefill_interval) + num_loops = self.engine_strategy.get_num_loops(not prefill) running = scheduler_output.running swap_in_map = scheduler_output.swap_in_map swap_out_map = scheduler_output.swap_out_map @@ -1008,11 +960,10 @@ def __get_output_start_pos(seqs: SeqList): # create inputs inputs = self.create_model_inputs(running, prefill) - sampling_inputs = __make_sampling_inputs(running) - num_appendable_ids = __get_num_appendable_ids(running) + sampling_inputs = self.sampling_strategy.make_sampling_inputs(running) return_logits = __need_logits(running) - dllm_mask = __get_dllm_mask(running) - output_start_pos = __get_output_start_pos(running) + extra_inputs = self.model_agent_strategy.make_extra_inputs(running) + stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running) sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num return dict( @@ -1022,12 +973,11 @@ def __get_output_start_pos(seqs: SeqList): swap_out_map=swap_out_map, loop_count=num_loops, sampling_inputs=sampling_inputs, - num_appendable_ids=num_appendable_ids, + stopping_criteria=stopping_criteria, return_logits=return_logits, is_dummy=False, sync_long_context=sync_long_context, - dllm_mask=dllm_mask, - output_start_pos=output_start_pos, + extra_inputs=extra_inputs, ) async def _await_forward_event(self, forward_event: asyncio.Event): diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index a1ace91884..b30fbb3992 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -7,7 +7,6 @@ import torch from lmdeploy.messages import LogitsProcessor -from lmdeploy.pytorch.config import DLLMConfig from lmdeploy.tokenizer import Tokenizer from ..messages import SchedulerSequence @@ -113,44 +112,6 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, guided SeqList = List[SchedulerSequence] -def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): - """Gather history.""" - if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): - return None - batch = len(seqs) - max_len = max(seq.num_valid_ids for seq in seqs) - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_valid_ids - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.valid_ids) - output[idx, -h_len:] = h_ids - return output - - -def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): - """Gather input ids for guided decode.""" - if not any(sampling_inputs.response_formats or ()): - return None - batch = len(seqs) - max_len = max(seq.num_new_tokens for seq in seqs) - output = torch.full((batch, max_len), pad_id, dtype=torch.int64) - for idx, seq in enumerate(seqs): - h_len = seq.num_new_tokens - if h_len == 0: - continue - h_ids = torch.from_numpy(seq.generated_ids) - output[idx, -h_len:] = h_ids - return output - - -def _get_num_ignore_eos(seqs: SeqList): - """Get num ignore eos.""" - ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq 
in seqs] - return torch.tensor(ret) - - @dataclass class SamplingInputs: temperature: torch.Tensor = None @@ -174,140 +135,6 @@ class SamplingInputs: num_ignore_eos: torch.Tensor = None batch_size: int = 0 - @classmethod - def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0): - """From samplingg params.""" - batch_size = len(seqs) - temperature = [None] * batch_size - repetition_penalty = [None] * batch_size - top_k = [None] * batch_size - top_p = [None] * batch_size - min_p = [None] * batch_size - bad_words = [None] * batch_size - stop_words = [None] * batch_size - random_seeds = [torch.seed() & 0xffffffff] * batch_size - random_offsets = [None] * batch_size - response_formats = [None] * batch_size - logits_processors = [None] * batch_size - num_logprobs = [None] * batch_size - - def __gather_params(): - """Gather params.""" - for idx, seq in enumerate(seqs): - param = seq.sampling_param - temperature[idx] = param.temperature - repetition_penalty[idx] = param.repetition_penalty - top_k[idx] = param.top_k - top_p[idx] = param.top_p - min_p[idx] = param.min_p - random_offsets[idx] = seq.num_valid_ids - response_formats[idx] = param.response_format - if param.random_seed is not None: - random_seeds[idx] = param.random_seed & 0xffffffff - - bw = param.bad_words - sw = param.stop_words - if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens): - bw = bw + sw - bad_words[idx] = bw - stop_words[idx] = sw - logits_processors[idx] = param.logits_processors - num_logprobs[idx] = param.num_logprobs - - def __get_topp(top_p): - """Get topp.""" - min_top_p = min(top_p) - if min_top_p == 1.0: - top_p = None - else: - top_p = torch.tensor(top_p) - return top_p, min_top_p - - def __get_minp(min_p): - """Get minp.""" - max_min_p = max(min_p) - if max_min_p == 0.0: - min_p = None - else: - min_p = torch.Tensor(min_p) - return min_p - - def __get_bad_words(bad_words): - """Get bad words.""" - max_bw_len = max(len(bw) for bw in bad_words) - if max_bw_len == 0: - return None, None - if all(len(bw) == max_bw_len for bw in bad_words): - ret = torch.tensor(bad_words) - mask = torch.ones_like(ret, dtype=bool) - return ret, mask - ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) - for idx, bw in enumerate(bad_words): - bw_len = len(bw) - if bw_len == 0: - continue - bw = ret.new_tensor(bw) - ret[idx, :bw_len] = bw - - mask = ret >= 0 - ret = ret.where(mask, 0) - return ret, mask - - __gather_params() - - if all(rp == 1.0 for rp in repetition_penalty): - repetition_penalty = None - else: - repetition_penalty = torch.tensor(repetition_penalty) - - temperature = torch.tensor(temperature) - - bad_words, bad_mask = __get_bad_words(bad_words) - stop_words, stop_mask = __get_bad_words(stop_words) - - max_top_k = max(top_k) - if min(top_k) <= 0: - max_top_k = 0 - if max_top_k == 1: - top_k = None - top_p, min_top_p = None, 1.0 - min_p = None - random_seeds = None - random_offsets = None - else: - top_k = torch.tensor(top_k) - top_p, min_top_p = __get_topp(top_p) - min_p = __get_minp(min_p) - random_seeds = torch.tensor(random_seeds) - random_offsets = torch.tensor(random_offsets) - - max_num_logprobs = max(num_logprobs) - - sampling_input = cls( - temperature=temperature, - bad_words=bad_words, - bad_mask=bad_mask, - stop_words=stop_words, - stop_mask=stop_mask, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - min_p=min_p, - random_seeds=random_seeds, - random_offsets=random_offsets, - response_formats=tuple(response_formats), - 
max_top_k=max_top_k, - min_top_p=min_top_p, - logits_processors=logits_processors, - max_num_logprobs=max_num_logprobs, - batch_size=batch_size, - ) - - sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) - sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input) - sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) - return sampling_input - def to_device(self, device: str, non_blocking: bool = False): """To device.""" out_dict = dict() @@ -320,71 +147,6 @@ def to_device(self, device: str, non_blocking: bool = False): return SamplingInputs(**out_dict) - def step(self, next_token_ids: torch.Tensor, **kwargs): - """To next step.""" - self.num_ignore_eos = self.num_ignore_eos - 1 - if self.all_ids is not None: - self.all_ids = torch.cat([self.all_ids, next_token_ids[:, None]], 1) - if self.guided_input_ids is not None: - self.guided_input_ids = torch.cat([self.guided_input_ids, next_token_ids[:, None]], 1) - - -@dataclass -class SamplingInputsDLLM(SamplingInputs): - - @classmethod - def from_sampling_params(cls, seqs: List[SchedulerSequence], pad_token_id: int = 0, dllm_config: DLLMConfig = None): - """From samplingg params.""" - out = super().from_sampling_params(seqs, pad_token_id) - assert dllm_config is not None - dllm_block_length = dllm_config.dllm_block_length - - # repeat tensor - update_attr_names = [ - 'temperature', - 'bad_words', - 'bad_mask', - 'stop_words', - 'stop_mask', - 'repetition_penalty', - 'top_k', - 'top_p', - 'min_p', - 'random_seeds', - 'random_offsets', - 'all_ids', - 'guided_input_ids', - 'num_ignore_eos', - ] - for name in update_attr_names: - attr = getattr(out, name) - if attr is None: - continue - repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) - attr = attr[None].repeat(*repeats).flatten(0, 1) - setattr(out, name, attr) - - if len(out.response_formats) > 0: - new_resp_formats = [] - for resp in out.response_formats: - new_resp_formats += [resp] * dllm_block_length - out.response_formats = tuple(new_resp_formats) - - out.batch_size *= dllm_block_length - - return out - - def step(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, **kwargs): - """To next step.""" - from lmdeploy.pytorch import consts - batch_size = self.batch_size - dllm_block_size = next_token_ids.numel() // batch_size - DLLM_UNMASKED = consts.DLLM_UNMASKED - is_unmasked = (dllm_mask == DLLM_UNMASKED).view(batch_size, -1).all(dim=1, keepdim=True) - num_ignore_eos = self.num_ignore_eos.view(batch_size, -1) - num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos) - self.num_ignore_eos = num_ignore_eos.flatten() - def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits): """Apply custom logits processors.""" diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index d177ea4f10..7b7d6a8905 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -14,7 +14,6 @@ import torch.distributed as dist from torch.profiler import ProfilerActivity, profile, record_function -from lmdeploy.pytorch import consts from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.serve.openai.protocol import UpdateParamsRequest from lmdeploy.utils import get_logger @@ -25,11 +24,12 @@ from ..distributed import DistContext, get_dist_manager from ..model_inputs import ModelInputs, step_ctx_manager from ..models.patch import BuildModelContext, add_adapters, build_patched_model, 
update_custom_module_map +from ..strategies import build_strategy_factory +from ..strategies.base.model_agent import ExtraInputs, ExtraOutputs, StoppingCriteria from ..utils import get_gpu_memory from ..weight_loader.model_weight_loader import load_model_weights from .cache_engine import CacheEngine from .logits_process import FusedLogitsProcessor, SamplingInputs -from .unmasking import UnmaskingProcessor logger = get_logger('lmdeploy') @@ -69,7 +69,7 @@ class BatchedOutputs: model_metas: List[Dict[str, Any]] = None logprobs: Optional[BatchedLogProbs] = None new_token_timestamp: int = 0 - dllm_mask: Optional[torch.Tensor] = None + extra_outputs: Optional[ExtraOutputs] = None def to_cpu(self): """To cpu.""" @@ -247,81 +247,6 @@ def model_forward( return dict(hidden_states=output, model_metas=model_metas) -def _batch_stopping_criteria_default(token_ids: torch.Tensor, stop_words: torch.Tensor, - num_appendable_ids: torch.Tensor): - """Batched stopping criteria.""" - num_appendable_ids = num_appendable_ids - 1 - stopped = num_appendable_ids <= 0 - stop_pos = torch.zeros_like(num_appendable_ids) - if stop_words is not None: - sw_stopped = (token_ids[:, None] == stop_words).any(1) - stopped = stopped | sw_stopped - one_ids = torch.clamp_max(num_appendable_ids, 0) - num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) - return stopped, stop_pos, num_appendable_ids - - -def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor, - stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor, - output_start_pos: torch.Tensor, inputs: ModelInputs): - num_tokens = token_ids.size(0) - batch_size = num_appendable_ids.size(0) - block_size = num_tokens // batch_size - - # blocks might contain stop words in prev-round chat - # these stop words should be ignored - kv_seqlens = inputs.history_lengths + inputs.seq_length - ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0) - ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device) - ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten() - token_ids = token_ids.clone() - token_ids[ignore_mask] = -1 - - # find stop words - sw_stopped = (token_ids[:, None] == stop_words).any(1) - sw_stopped = sw_stopped.view(batch_size, block_size) - sw_stop_pos = sw_stopped.int().argmax(1) - - stop_pos = torch.where(stopped, stop_pos, sw_stop_pos) - sw_stopped = sw_stopped.any(dim=1) - sw_stopped = sw_stopped & is_unmasked - stopped = stopped | sw_stopped - - # update num_appendable_ids - one_ids = torch.clamp_max(num_appendable_ids, 0) - num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) - - return stopped, stop_pos, num_appendable_ids - - -def _batch_stopping_criteria_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, num_appendable_ids: torch.Tensor, - dllm_mask: torch.Tensor, output_start_pos: torch.Tensor, inputs: ModelInputs): - """Batched stopping criteria.""" - num_tokens = token_ids.size(0) - batch_size = num_appendable_ids.size(0) - block_size = num_tokens // batch_size - - dllm_mask = dllm_mask.view(batch_size, block_size) - is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1) - - # check stop by num_new_tokens - num_appendable_ids -= is_unmasked * block_size - stopped = num_appendable_ids <= 0 - stop_pos = block_size - 1 + num_appendable_ids - - # check stop words - if stop_words is not None: - stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids, - 
stop_words, - is_unmasked, - stopped, - stop_pos, - num_appendable_ids, - output_start_pos=output_start_pos, - inputs=inputs) - return stopped, stop_pos, num_appendable_ids - - def _try_to_cuda(val, non_blocking: bool = False): if val is None: return val @@ -431,14 +356,10 @@ def __init__(self, self.enable_microbatch_decode_batchsize_threshold = \ int(getenv('ENABLE_MICROBATCH_DECODE_BATCHSIZE_THRESHOLD', 2)) - # dllm - self.unmasking_processor = self._build_unmasking_processor() - - def _build_unmasking_processor(self): - """Build unmasking processor.""" - dllm_config = self.misc_config.dllm_config - unmasking_processor = UnmaskingProcessor(dllm_config) - return unmasking_processor + # strategy + self.strategy_factory = build_strategy_factory(model_config, misc_config) + self.inputs_strategy = self.strategy_factory.build_model_inputs_strategy() + self.agent_strategy = self.strategy_factory.build_model_agent_strategy() @contextmanager def all_context(self): @@ -470,101 +391,34 @@ def warmup(self): num_tokens = max_batches # warmup prefill - inputs = ModelInputs.make_dummy(max_batches, - is_decoding=False, - device='cuda', - vocab_size=self.model_config.vocab_size, - build_ctx=self.build_model_ctx) + inputs = self.inputs_strategy.make_dummy(max_batches, + is_decoding=False, + device='cuda', + vocab_size=self.model_config.vocab_size) self._forward_impl(inputs) # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() capture_batch_sizes = sorted(capture_batch_sizes, reverse=True) for num_tokens in capture_batch_sizes: - inputs = ModelInputs.make_dummy(num_tokens, - is_decoding=True, - device='cuda', - vocab_size=self.model_config.vocab_size, - build_ctx=self.build_model_ctx) + inputs = self.inputs_strategy.make_dummy(num_tokens, + is_decoding=True, + device='cuda', + vocab_size=self.model_config.vocab_size) self._forward_impl(inputs) def _slice_outs(self, inputs: torch.Tensor, seq_length: torch.LongTensor): """Slice outputs.""" - if self.model_config.model_paradigm == 'dllm': - block_length = self.misc_config.dllm_config.dllm_block_length - if len(seq_length) * block_length == inputs.size(0): - return inputs - last_idx = seq_length.cumsum(0) - block_range = torch.arange(-block_length, 0, device=last_idx.device) - index = (last_idx[:, None] + block_range[None, :]).flatten() - inputs = inputs[index] - return inputs - else: - if len(seq_length) == inputs.size(0): - return inputs - last_idx = seq_length.cumsum(-1) - 1 - return inputs[last_idx] - - def _postprocess_forward_output_dllm(self, output: dict, inputs: ModelInputs, is_long_context: bool, - return_logits: bool): - """Post process for dllm.""" - block_length = self.misc_config.dllm_config.dllm_block_length - hidden_states = output['hidden_states'] - if is_long_context: - if not return_logits: - hidden_states = hidden_states[:, -block_length:] - else: - hidden_states = hidden_states.to('cuda') - else: - seq_length = inputs.seq_length - is_decoding = seq_length.numel() * block_length == hidden_states.size(1) - if not return_logits and not is_decoding: - hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] - output['hidden_states'] = hidden_states - return output + return self.agent_strategy.slice_outputs(inputs, seq_length) - def _postprocess_forward_output_default(self, output: dict, inputs: ModelInputs, is_long_context: bool, - return_logits: bool): - """Post process forward output default.""" + def _postprocess_forward_output(self, output: dict, inputs: ModelInputs): + """Post 
process forward output.""" hidden_states = output['hidden_states'] - if not is_long_context: - seq_length = inputs.seq_length - is_decoding = seq_length.numel() != hidden_states.size(1) - if not return_logits and not is_decoding: - hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] - else: - if not return_logits: - last_token_loc = [-1] - hidden_states = hidden_states[:, last_token_loc] - else: - hidden_states = hidden_states.to('cuda') + seq_length = inputs.seq_length + hidden_states = self._slice_outs(hidden_states[0], seq_length)[None] output['hidden_states'] = hidden_states return output - def _postprocess_forward_output(self, output: dict, inputs: ModelInputs, is_long_context: bool, - return_logits: bool): - """Post process forward output.""" - if self.model_config.model_paradigm == 'dllm': - return self._postprocess_forward_output_dllm(output, inputs, is_long_context, return_logits) - else: - return self._postprocess_forward_output_default(output, inputs, is_long_context, return_logits) - - @record_function('stopping_criteria') - def _batch_stopping_criteria(self, - token_ids: torch.Tensor, - stop_words: torch.Tensor, - num_appendable_ids: torch.Tensor, - dllm_mask: Optional[torch.Tensor] = None, - output_start_pos: Optional[torch.Tensor] = None, - inputs: Optional[ModelInputs] = None): - """Batched stopping criteria.""" - if self.model_config.model_paradigm == 'dllm': - assert dllm_mask is not None - return _batch_stopping_criteria_dllm(token_ids, stop_words, num_appendable_ids, dllm_mask, output_start_pos, - inputs) - - return _batch_stopping_criteria_default(token_ids, stop_words, num_appendable_ids) - async def _async_model_forward( self, inputs: ModelInputs, @@ -580,7 +434,8 @@ class _OutputGather: def __init__(self, max_seq_len): self._max_seq_len = max_seq_len self._start = 0 - self._output = None + self._output: torch.Tensor = None + self._device: torch.device = None def gather(self, output): """gather.""" @@ -595,6 +450,7 @@ def gather(self, output): seq_len = tmp_output.size(-2) if out_logits is None: out_logits = tmp_output.new_empty(1, self._max_seq_len, tmp_output.size(-1), device='cpu') + self._device = tmp_output.device out_logits[:, start:start + seq_len].copy_(tmp_output, non_blocking=True) self._start = start + seq_len self._output = out_logits @@ -604,7 +460,7 @@ def get_output(self): if not return_logits: return self._output[:, -1:] torch.cuda.synchronize() - return self._output + return self._output.to(self._device) __forward = self.async_forward @@ -652,15 +508,12 @@ async def __long_context_single_forward(new_inputs, max_seqlen: int): else: ret = await __long_context_single_forward(inputs, max_seqlen) - ret = self._postprocess_forward_output(ret, origin_inputs, is_long_context, return_logits) + if not return_logits: + ret = self._postprocess_forward_output(ret, origin_inputs) # compute dummy loop if dummy_loop > 0: - dummy_inputs = ModelInputs.make_dummy(1, - False, - 'cuda', - vocab_size=self.model_config.vocab_size, - build_ctx=self.build_model_ctx) + dummy_inputs = self.inputs_strategy.make_dummy(1, False, 'cuda', vocab_size=self.model_config.vocab_size) for _ in range(dummy_loop): await __forward(dummy_inputs) @@ -710,15 +563,6 @@ def _broadcast_next_token(self, next_token_ids: torch.Tensor, dist_ctx: DistCont yield handle.wait() - def dllm_unmasking(self, inputs: ModelInputs, logits: torch.Tensor, next_token_ids: torch.LongTensor, - dllm_mask: torch.Tensor): - """Unmasking dllm.""" - if self.build_model_ctx.model_paradigm != 'dllm': - 
return None, next_token_ids - input_ids = inputs.input_ids - input_ids = self._slice_outs(input_ids.flatten(), inputs.seq_length) - return self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask) - async def _async_step_background( self, inputs: ModelInputs, @@ -726,48 +570,25 @@ async def _async_step_background( swap_in_map: Dict = None, swap_out_map: Dict = None, sampling_inputs: SamplingInputs = None, - num_appendable_ids: torch.LongTensor = None, + stopping_criteria: StoppingCriteria = None, return_logits: bool = False, is_dummy: bool = False, sync_long_context: bool = False, - dllm_mask: torch.Tensor = None, - output_start_pos: torch.Tensor = None, + extra_inputs: ExtraInputs = None, ): """Asyc forward task.""" dist_ctx = get_dist_manager().current_context() - def __update_dllm(next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor): - """Update token_ids and dllm_mask.""" - model_paradigm = self.model_config.model_paradigm - if model_paradigm != 'dllm': - return next_token_ids, dllm_mask, seqlens - - dllm_mask_token = self.model_config.dllm_mask_token - dllm_config = self.misc_config.dllm_config - dllm_block_length = dllm_config.dllm_block_length - - # reshape to (batch, dllm_block_length) - next_token_ids = next_token_ids.view(-1, dllm_block_length).clone() - dllm_mask = dllm_mask.view(-1, dllm_block_length).clone() - - # flags - is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1) - - is_masked = (dllm_mask == consts.DLLM_MASKED) - next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token - dllm_mask[is_cached] = consts.DLLM_MASKED - seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, ))) - - return next_token_ids.flatten(), dllm_mask.flatten(), seqlens - @record_function('update_inputs_for_next_step') - def __update_inputs(next_token_ids, model_metas, dllm_mask): + def __update_inputs(next_token_ids, model_metas, extra_inputs): """Update inputs.""" - inputs.model_metas = model_metas - next_token_ids, dllm_mask, step_seqlens = __update_dllm(next_token_ids, dllm_mask, inputs.seq_length) - inputs.step(next_token_ids, step_seqlens) - sampling_inputs.step(next_token_ids, dllm_mask=dllm_mask) - return next_token_ids, dllm_mask + return self.agent_strategy.update_inputs_for_next_step( + inputs, + sampling_inputs, + next_token_ids=next_token_ids, + model_metas=model_metas, + extra_inputs=extra_inputs, + ) @asynccontextmanager async def __prepare_dp(): @@ -863,8 +684,7 @@ async def __prepare_dp(): logits = output['logits'] logits = logits[0] # [bs, seq, prob] -> [seq, prob] last_logits = self._slice_outs(logits, inputs.seq_length) - if dllm_mask is not None: - dllm_mask = self._slice_outs(dllm_mask, inputs.seq_length) + extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs.seq_length) # output empty for dummy inputs if is_dummy: @@ -881,35 +701,35 @@ async def __prepare_dp(): with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - # unmasking - dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) + # post sampling + next_token_ids, extra_inputs = self.agent_strategy.post_sampling( + inputs, last_logits, next_token_ids, extra_inputs) # stopping criteria - stopped, stop_pos, num_appendable_ids = self._batch_stopping_criteria( - next_token_ids, - sampling_inputs.stop_words, - num_appendable_ids, - dllm_mask=dllm_mask, - output_start_pos=output_start_pos, - 
inputs=inputs) + stopped, stop_pos, stopping_criteria = stopping_criteria.step(next_token_ids, + sampling_inputs.stop_words, + inputs=inputs, + extra_inputs=extra_inputs) else: # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`, # as it can trigger recompilation on different ranks when using torch.compile. with torch.inference_mode(): - next_token_ids = num_appendable_ids.new_zeros(last_logits.size(0)) + next_token_ids = inputs.input_ids.new_zeros(last_logits.size(0)) logprobs = None # broadcast next token for TP > 1 with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') - # unmasking - dllm_mask, next_token_ids = self.dllm_unmasking(inputs, last_logits, next_token_ids, dllm_mask) + # post sampling + next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids, + extra_inputs) # send output model_metas = output.get('model_metas') if need_output: logger.debug(f' rank[{rank}]: Output [{idx}]') + extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs) self._push_output( BatchedOutputs(next_token_ids=next_token_ids, logits=logits if return_logits else None, @@ -917,11 +737,11 @@ async def __prepare_dp(): stop_pos=stop_pos, model_metas=model_metas, logprobs=logprobs, - dllm_mask=dllm_mask)) + extra_outputs=extra_outputs)) # update for next loop if is_decoding and idx < loop_count - 1: - next_token_ids, dllm_mask = __update_inputs(next_token_ids, model_metas, dllm_mask) + inputs, extra_inputs = __update_inputs(next_token_ids, model_metas, extra_inputs) async def _async_loop_background(self, forward_event: asyncio.Event = None): """Async loop background.""" @@ -949,7 +769,7 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None): async def _async_loop_inputs_preprocess(self): """Async loop inputs preprocess.""" non_blocking = True - keys = ['inputs', 'sampling_inputs', 'num_appendable_ids', 'dllm_mask', 'output_start_pos'] + keys = ['inputs', 'sampling_inputs', 'stopping_criteria', 'extra_inputs'] while True: forward_inputs = await self._pre_in_que.get() @@ -1086,8 +906,8 @@ def _build_model(self): update_custom_module_map(custom_module_map) logger.debug(msg_with_rank(rank, 'build model.')) build_model_ctx = BuildModelContext(disable_vision_encoder=self.misc_config.disable_vision_encoder, - model_paradigm=self.model_config.model_paradigm, - dllm_config=self.misc_config.dllm_config) + dllm_config=self.misc_config.dllm_config, + strategy_factory=self.strategy_factory) patched_model = build_patched_model(self.model_config, device=device, model_format=self.misc_config.model_format, @@ -1255,6 +1075,7 @@ def __init__(self, model_agent: BaseModelAgent): self.model_config = model_agent.model_config self.cache_config = model_agent.cache_config self.misc_config = model_agent.misc_config + self.inputs_strategy = model_agent.inputs_strategy self.device = model_agent.device self._in_que = model_agent._in_que @@ -1270,11 +1091,10 @@ def _make_dummy_forward_inputs(self): dist_config = self.dist_ctx.dist_config batch_size = 2 if dist_config.enable_microbatch else 1 batch_size = min(self.cache_config.max_batches, batch_size) - model_inputs = ModelInputs.make_dummy(batch_size, - is_decoding, - device=self.device, - vocab_size=self.model_config.vocab_size, - build_ctx=self.model_agent.build_model_ctx) + model_inputs = self.inputs_strategy.make_dummy(batch_size, + is_decoding, + device=self.device, + 
vocab_size=self.model_config.vocab_size) forward_inputs = dict( inputs=model_inputs, loop_count=loop_count, diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index fcbb3ff200..b6fc2bee25 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -57,7 +57,7 @@ class SamplingParam: num_logprobs: int = -1 @classmethod - def from_gen_config(self, gen_config: GenerationConfig): + def from_gen_config(cls, gen_config: GenerationConfig): """From gen config.""" min_new_tokens = gen_config.min_new_tokens or 0 @@ -161,7 +161,7 @@ class MessageStatus(enum.Enum): class SequenceMeta: """Meta data shared by all sequence.""" block_size: int - model_paradigm: str = 'llm' + model_paradigm: str = 'ar' dllm_block_length: int = 1 dllm_mask_token: int = 151669 @@ -908,6 +908,6 @@ def update_token_ids(self, SEQ_CLS_MAP = dict( - llm=SchedulerSequenceDefault, + ar=SchedulerSequenceDefault, dllm=SchedulerSequenceDLLM, ) diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py index 2d43b33a4e..a377c9d4d6 100644 --- a/lmdeploy/pytorch/model_inputs.py +++ b/lmdeploy/pytorch/model_inputs.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager from dataclasses import dataclass, field, fields -from typing import Any, Dict, List, Literal +from typing import TYPE_CHECKING, Any, Dict, List, Literal import torch from torch.profiler import record_function @@ -12,6 +12,9 @@ from lmdeploy.pytorch.config import DLLMConfig, ModelConfig from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base import StrategyFactoryBase + @dataclass class DPMeta: @@ -147,7 +150,7 @@ def step(self, input_ids: torch.LongTensor, step_seqlens: torch.Tensor = None): assert self.is_decoding if step_seqlens is None: step_seqlens = self.seq_length - self.history_lengths = self.history_lengths + step_seqlens + self.history_lengths += step_seqlens self.max_kv_seqlen += self.max_q_seqlen self.sum_kv_seqlen += self.max_kv_seqlen * self.seq_length.numel() if input_ids.dim() == 1: @@ -281,47 +284,6 @@ def build_dp_meta(self): """Build dp meta.""" self.dp_meta = DPMeta.build(self.input_ids.numel()) - @classmethod - @record_function('make_dummy_input') - def make_dummy(cls, - batch_size: int, - is_decoding: bool, - device: str = 'cpu', - dummy_block_id: int = 0, - vocab_size: int = 1, - build_ctx: 'BuildModelContext' = None): - """Make dummy inputs.""" - model_paradigm = build_ctx.model_paradigm - if model_paradigm == 'dllm': - block_size = build_ctx.dllm_config.dllm_block_length - max_q_seqlen = block_size - else: - max_q_seqlen = 1 - num_tokens = batch_size * max_q_seqlen - max_kv_seqlen = max_q_seqlen - input_ids = torch.randint(0, vocab_size, ( - 1, - num_tokens, - ), dtype=torch.long, device=device) - seq_length = torch.ones((batch_size, ), dtype=torch.long, device=device) - history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device) - block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) - num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device) - local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device) - - return cls( - input_ids=input_ids, - seq_length=seq_length, - history_lengths=history_lengths, - block_offsets=block_offsets, - is_decoding=is_decoding, - num_ignored_history=num_ignored_history, - max_q_seqlen=max_q_seqlen, - max_kv_seqlen=max_kv_seqlen, - 
sum_kv_seqlen=batch_size, - local_adapter_ids=local_adapter_ids, - ) - def log_info(self): """Get log info.""" ret = (f'num_tokens={self.input_ids.numel()}, batch_size={self.seq_length.numel()}' @@ -475,8 +437,8 @@ def get_mask_and_position_ids(cls, inputs: ModelInputs): class BuildModelContext: """Context for building model.""" disable_vision_encoder: bool = False - model_paradigm: str = 'llm' dllm_config: DLLMConfig = None + strategy_factory: 'StrategyFactoryBase' = None class StepContextManager: diff --git a/lmdeploy/pytorch/strategies/__init__.py b/lmdeploy/pytorch/strategies/__init__.py new file mode 100644 index 0000000000..c0f5da1262 --- /dev/null +++ b/lmdeploy/pytorch/strategies/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.config import MiscConfig, ModelConfig + + +def build_strategy_factory(model_config: ModelConfig, misc_config: MiscConfig): + """Build strategy factory.""" + model_paradigm = model_config.model_paradigm + + if model_paradigm == 'ar': + from .ar import ARStrategyFactory + return ARStrategyFactory(model_config=model_config) + elif model_paradigm == 'dllm': + from .dllm import DLLMStrategyFactory + return DLLMStrategyFactory(model_config=model_config, dllm_config=misc_config.dllm_config) + else: + raise RuntimeError(f'Unsupported model paradigm: {model_paradigm}') diff --git a/lmdeploy/pytorch/strategies/ar/__init__.py b/lmdeploy/pytorch/strategies/ar/__init__.py new file mode 100644 index 0000000000..be68910e2e --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/__init__.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING + +from lmdeploy.pytorch.config import ModelConfig + +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy + from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy + from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy + from lmdeploy.pytorch.strategies.base.engine import EngineStrategy + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base import StrategyFactoryBase + + +class ARStrategyFactory(StrategyFactoryBase): + + def __init__(self, model_config: ModelConfig): + """config.""" + self.model_config = model_config + + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + from .cudagraph import ARCudagraphStrategy + return ARCudagraphStrategy() + + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + from .sampling import ARSamplingStrategy + pad_token_id = self.model_config.bos_token_id + pad_token_id = 0 if pad_token_id is None else pad_token_id + return ARSamplingStrategy(pad_token_id) + + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + from .model_inputs import ARModelInputsStrategy + return ARModelInputsStrategy() + + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + from .model_agent import ARModelAgentStrategy + return ARModelAgentStrategy() + + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + from .engine import AREngineStrategy + return AREngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config) diff --git 
a/lmdeploy/pytorch/strategies/ar/cudagraph.py b/lmdeploy/pytorch/strategies/ar/cudagraph.py new file mode 100644 index 0000000000..e3749bcfc2 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/cudagraph.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..base.cudagraph import CudagraphStrategy + + +class ARCudagraphStrategy(CudagraphStrategy): + + def get_max_tokens(self, batch_size: int) -> int: + """Get max tokens.""" + return batch_size diff --git a/lmdeploy/pytorch/strategies/ar/engine.py b/lmdeploy/pytorch/strategies/ar/engine.py new file mode 100644 index 0000000000..60ae69f8c9 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/engine.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base.engine import EngineStrategy + + +class AREngineStrategy(EngineStrategy): + """AR Engine Strategy.""" + + def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + + def get_prealloc_size(self, is_decoding: bool): + """Get prealloc_size.""" + return self.scheduler_config.prefill_interval if is_decoding else 0 + + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + return self.scheduler_config.prefill_interval if is_decoding else 1 diff --git a/lmdeploy/pytorch/strategies/ar/model_agent.py b/lmdeploy/pytorch/strategies/ar/model_agent.py new file mode 100644 index 0000000000..4096db2cb7 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/model_agent.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Any, List, Optional + +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria + +SeqList = List[SchedulerSequence] + + +class ARExtraInputs(ExtraInputs): + """Ar extra inputs.""" + + +class ARExtraOutputs(ExtraOutputs): + """Ar extra outputs.""" + + +@dataclass +class ARStoppingCriteria(StoppingCriteria): + num_appendable_ids: torch.Tensor + + @record_function('stopping_criteria') + def step(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional[ModelInputs] = None, + extra_inputs: Optional[ARExtraInputs] = None): + """Check whether to stop generation.""" + num_appendable_ids = self.num_appendable_ids - 1 + stopped = num_appendable_ids <= 0 + stop_pos = torch.zeros_like(num_appendable_ids) + if stop_words is not None: + sw_stopped = (token_ids[:, None] == stop_words).any(1) + stopped = stopped | sw_stopped + one_ids = torch.clamp_max(num_appendable_ids, 0) + num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + + # I don't know why assign inplace does not works... 
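+        # The criteria is therefore replaced rather than mutated: a fresh ARStoppingCriteria
+        # is returned below and reassigned by the caller on every step.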
+ new_stopping = ARStoppingCriteria(num_appendable_ids=num_appendable_ids) + return stopped, stop_pos, new_stopping + + +class ARModelAgentStrategy(ModelAgentStrategy): + + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + # batch size == 1 + if len(seq_length) == 1: + return inputs[-1:] + + if len(seq_length) == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(-1) - 1 + return inputs[last_idx] + + def slice_extra_inputs(self, extra_inputs: ARExtraInputs, seq_length: torch.LongTensor) -> ARExtraInputs: + """Slice outputs.""" + return extra_inputs + + def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor): + """step.""" + sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1 + + all_ids = sampling_inputs.all_ids + if all_ids is not None: + sampling_inputs.all_ids = torch.cat([all_ids, next_token_ids[:, None]], 1) + + guided_input_ids = sampling_inputs.guided_input_ids + if guided_input_ids is not None: + sampling_inputs.guided_input_ids = torch.cat([guided_input_ids, next_token_ids[:, None]], 1) + + return sampling_inputs + + def make_stopping_criteria(self, seqs: SeqList) -> ARStoppingCriteria: + """Create stopping criteria.""" + num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] + num_appendable = torch.tensor(num_appendable) + return ARStoppingCriteria(num_appendable_ids=num_appendable) + + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + return ARExtraInputs() + + def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs: + """Create extra outputs.""" + return ARExtraOutputs() + + def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ARExtraInputs, + **kwargs): + """Step next inputs.""" + model_inputs.model_metas = model_metas + step_seqlens = model_inputs.seq_length + model_inputs.step(next_token_ids, step_seqlens) + self._step_sampling_inputs(sampling_inputs, next_token_ids) + return model_inputs, extra_inputs + + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: ARExtraInputs): + """Post sampling.""" + return next_token_ids, extra_inputs diff --git a/lmdeploy/pytorch/strategies/ar/model_inputs.py b/lmdeploy/pytorch/strategies/ar/model_inputs.py new file mode 100644 index 0000000000..5b263b1e37 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/model_inputs.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs + + +class ARModelInputsStrategy(ModelInputsStrategy): + + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1) -> ModelInputs: + """Create dummy model inputs.""" + return make_dummy_inputs(batch_size, + max_q_seqlen=1, + is_decoding=is_decoding, + device=device, + dummy_block_id=dummy_block_id, + vocab_size=vocab_size) diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py new file mode 100644 index 0000000000..b2516f091a --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
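+# Builds batched SamplingInputs for the autoregressive paradigm: per-sequence sampling
+# parameters are gathered into tensors, and token histories are collected only when
+# repetition penalty, logits processors or guided decoding require them.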
+from typing import List + +import torch + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence + +from ..base.sampling import SamplingStrategy + +SeqList = List[SchedulerSequence] + + +def _gather_all_ids(pad_id: int, seqs: SeqList, sampling_inputs: SamplingInputs): + """Gather history.""" + if sampling_inputs.repetition_penalty is None and not any(sampling_inputs.logits_processors): + return None + batch = len(seqs) + max_len = max(seq.num_valid_ids for seq in seqs) + output = torch.full((batch, max_len), pad_id, dtype=torch.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_valid_ids + if h_len == 0: + continue + h_ids = torch.from_numpy(seq.valid_ids) + output[idx, -h_len:] = h_ids + return output + + +def _gather_guided_input_ids(pad_id: int, seqs: SeqList, sampling_inputs: 'SamplingInputs'): + """Gather input ids for guided decode.""" + if not any(sampling_inputs.response_formats or ()): + return None + batch = len(seqs) + max_len = max(seq.num_new_tokens for seq in seqs) + output = torch.full((batch, max_len), pad_id, dtype=torch.int64) + for idx, seq in enumerate(seqs): + h_len = seq.num_new_tokens + if h_len == 0: + continue + h_ids = torch.from_numpy(seq.generated_ids) + output[idx, -h_len:] = h_ids + return output + + +def _get_num_ignore_eos(seqs: SeqList): + """Get num ignore eos.""" + ret = [seq.sampling_param.min_new_tokens - seq.num_new_tokens for seq in seqs] + return torch.tensor(ret) + + +class ARSamplingStrategy(SamplingStrategy): + """Sampling strategy for autoregressive models.""" + + def __init__(self, pad_token_id: int) -> None: + pad_token_id = 0 if pad_token_id is None else pad_token_id + self.pad_token_id = pad_token_id + + def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: + """Create sampling inputs from the sequences.""" + batch_size = len(seqs) + temperature = [None] * batch_size + repetition_penalty = [None] * batch_size + top_k = [None] * batch_size + top_p = [None] * batch_size + min_p = [None] * batch_size + bad_words = [None] * batch_size + stop_words = [None] * batch_size + random_seeds = [torch.seed() & 0xffffffff] * batch_size + random_offsets = [None] * batch_size + response_formats = [None] * batch_size + logits_processors = [None] * batch_size + num_logprobs = [None] * batch_size + + def __gather_params(): + """Gather params.""" + for idx, seq in enumerate(seqs): + param = seq.sampling_param + temperature[idx] = param.temperature + repetition_penalty[idx] = param.repetition_penalty + top_k[idx] = param.top_k + top_p[idx] = param.top_p + min_p[idx] = param.min_p + random_offsets[idx] = seq.num_valid_ids + response_formats[idx] = param.response_format + if param.random_seed is not None: + random_seeds[idx] = param.random_seed & 0xffffffff + + bw = param.bad_words + sw = param.stop_words + if (not param.ignore_eos and seq.num_new_tokens < param.min_new_tokens): + bw = bw + sw + bad_words[idx] = bw + stop_words[idx] = sw + logits_processors[idx] = param.logits_processors + num_logprobs[idx] = param.num_logprobs + + def __get_topp(top_p): + """Get topp.""" + min_top_p = min(top_p) + if min_top_p == 1.0: + top_p = None + else: + top_p = torch.tensor(top_p) + return top_p, min_top_p + + def __get_minp(min_p): + """Get minp.""" + max_min_p = max(min_p) + if max_min_p == 0.0: + min_p = None + else: + min_p = torch.Tensor(min_p) + return min_p + + def __get_bad_words(bad_words): + """Get bad words.""" + max_bw_len = max(len(bw) for bw in bad_words) + if max_bw_len 
== 0: + return None, None + if all(len(bw) == max_bw_len for bw in bad_words): + ret = torch.tensor(bad_words) + mask = torch.ones_like(ret, dtype=bool) + return ret, mask + ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) + for idx, bw in enumerate(bad_words): + bw_len = len(bw) + if bw_len == 0: + continue + bw = ret.new_tensor(bw) + ret[idx, :bw_len] = bw + + mask = ret >= 0 + ret = ret.where(mask, 0) + return ret, mask + + __gather_params() + + if all(rp == 1.0 for rp in repetition_penalty): + repetition_penalty = None + else: + repetition_penalty = torch.tensor(repetition_penalty) + + temperature = torch.tensor(temperature) + + bad_words, bad_mask = __get_bad_words(bad_words) + stop_words, stop_mask = __get_bad_words(stop_words) + + max_top_k = max(top_k) + if min(top_k) <= 0: + max_top_k = 0 + if max_top_k == 1: + top_k = None + top_p, min_top_p = None, 1.0 + min_p = None + random_seeds = None + random_offsets = None + else: + top_k = torch.tensor(top_k) + top_p, min_top_p = __get_topp(top_p) + min_p = __get_minp(min_p) + random_seeds = torch.tensor(random_seeds) + random_offsets = torch.tensor(random_offsets) + + max_num_logprobs = max(num_logprobs) + + sampling_input = SamplingInputs( + temperature=temperature, + bad_words=bad_words, + bad_mask=bad_mask, + stop_words=stop_words, + stop_mask=stop_mask, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + min_p=min_p, + random_seeds=random_seeds, + random_offsets=random_offsets, + response_formats=tuple(response_formats), + max_top_k=max_top_k, + min_top_p=min_top_p, + logits_processors=logits_processors, + max_num_logprobs=max_num_logprobs, + batch_size=batch_size, + ) + + pad_token_id = self.pad_token_id + sampling_input.all_ids = _gather_all_ids(pad_token_id, seqs, sampling_input) + sampling_input.guided_input_ids = _gather_guided_input_ids(pad_token_id, seqs, sampling_input) + sampling_input.num_ignore_eos = _get_num_ignore_eos(seqs) + return sampling_input diff --git a/lmdeploy/pytorch/strategies/base/__init__.py b/lmdeploy/pytorch/strategies/base/__init__.py new file mode 100644 index 0000000000..030feb3e7e --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
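+# Abstract strategy factory: each model paradigm (ar, dllm) provides a concrete factory
+# that builds the cudagraph, sampling, model-inputs, model-agent and engine strategies.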
+from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + + from .cudagraph import CudagraphStrategy + from .engine import EngineStrategy + from .model_agent import ModelAgentStrategy + from .model_inputs import ModelInputsStrategy + from .sampling import SamplingStrategy + + +class StrategyFactoryBase(ABC): + pass + + @abstractmethod + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + pass + + @abstractmethod + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + pass + + @abstractmethod + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + pass + + @abstractmethod + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + pass + + @abstractmethod + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/cudagraph.py b/lmdeploy/pytorch/strategies/base/cudagraph.py new file mode 100644 index 0000000000..795c3b5350 --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/cudagraph.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + + +class CudagraphStrategy(ABC): + + @abstractmethod + def get_max_tokens(self, batch_size: int) -> int: + """Get max tokens.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/engine.py b/lmdeploy/pytorch/strategies/base/engine.py new file mode 100644 index 0000000000..1c8e0c4ef7 --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/engine.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + + +class EngineStrategy(ABC): + """Engine strategy.""" + + @abstractmethod + def get_prealloc_size(self, is_decoding: bool) -> int: + """Get prealloc_size.""" + pass + + @abstractmethod + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/model_agent.py b/lmdeploy/pytorch/strategies/base/model_agent.py new file mode 100644 index 0000000000..8701a256fa --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/model_agent.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
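+# Paradigm-agnostic model-agent interfaces: ExtraInputs/ExtraOutputs carry extra per-step
+# tensors through the forward loop, StoppingCriteria decides when generation stops, and
+# ModelAgentStrategy defines the hooks for output slicing, post-sampling and next-step updates.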
+from abc import ABC, abstractmethod +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, List, Optional + +import numpy as np +import torch + +if TYPE_CHECKING: + from lmdeploy.pytorch.engine.logits_process import SamplingInputs + from lmdeploy.pytorch.messages import SchedulerSequence + from lmdeploy.pytorch.model_inputs import ModelInputs + SeqList = List[SchedulerSequence] + + +def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + out_dict = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor): + v = v.to(device, non_blocking=non_blocking) + out_dict[k] = v + + return type(self)(**out_dict) + + +@dataclass +class ExtraInputs(ABC): + + def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + return to_device(self, device, non_blocking) + + +@dataclass +class ExtraOutputs(ABC): + + def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + return to_device(self, device, non_blocking) + + def to_cpu(self): + """To cpu.""" + return self.to_device('cpu', non_blocking=False) + + def to_numpy(self): + """To numpy.""" + out = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, torch.Tensor) and v.dtype != torch.bfloat16: + v = v.detach().numpy() + elif hasattr(v, 'to_numpy'): + v = v.to_numpy() + out[k] = v + return type(self)(**out) + + def to_tensor(self): + """To tensor.""" + out = dict() + for f in fields(self): + k = f.name + v = getattr(self, k) + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + elif hasattr(v, 'to_tensor'): + v = v.to_tensor() + out[k] = v + return type(self)(**out) + + +@dataclass +class StoppingCriteria(ABC): + """Base class for stopping criteria.""" + + @abstractmethod + def step(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional['ModelInputs'] = None, + extra_inputs: Optional[ExtraInputs] = None): + """Check whether to stop generation.""" + pass + + def to_device(self, device: str, non_blocking: bool = False): + """To device.""" + return to_device(self, device, non_blocking) + + +class ModelAgentStrategy(ABC): + """Base class for model agent strategies.""" + + @abstractmethod + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + pass + + @abstractmethod + def slice_extra_inputs(self, extra_inputs: ExtraInputs, seq_length: torch.LongTensor) -> ExtraInputs: + """Slice outputs.""" + pass + + @abstractmethod + def make_stopping_criteria(self, seqs: 'SeqList') -> StoppingCriteria: + """Create stopping criteria.""" + pass + + @abstractmethod + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + pass + + @abstractmethod + def make_extra_outputs(self, extra_inputs: ExtraInputs) -> ExtraOutputs: + """Create extra outputs.""" + pass + + @abstractmethod + def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: ExtraInputs, + **kwargs): + """Step next inputs.""" + pass + + @abstractmethod + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: ExtraInputs): + """Post sampling.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py new file mode 100644 index 0000000000..a8b97c775e --- /dev/null +++ 
b/lmdeploy/pytorch/strategies/base/model_inputs.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod + +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch.model_inputs import ModelInputs + + +@record_function('make_dummy_input') +def make_dummy_inputs(batch_size: int, + max_q_seqlen: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1): + """Make dummy inputs global implement.""" + num_tokens = batch_size * max_q_seqlen + max_kv_seqlen = max_q_seqlen + input_ids = torch.randint(0, vocab_size, ( + 1, + num_tokens, + ), dtype=torch.long, device=device) + seq_length = torch.ones((batch_size, ), dtype=torch.long, device=device) + history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device) + block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) + num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device) + local_adapter_ids = torch.zeros((batch_size, ), dtype=torch.long, device=device) + + return ModelInputs( + input_ids=input_ids, + seq_length=seq_length, + history_lengths=history_lengths, + block_offsets=block_offsets, + is_decoding=is_decoding, + num_ignored_history=num_ignored_history, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, + sum_kv_seqlen=batch_size, + local_adapter_ids=local_adapter_ids, + ) + + +class ModelInputsStrategy(ABC): + + @abstractmethod + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1) -> ModelInputs: + """Create dummy model inputs.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/sampling.py b/lmdeploy/pytorch/strategies/base/sampling.py new file mode 100644 index 0000000000..172454157b --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/sampling.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import List + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence + +SeqList = List[SchedulerSequence] + + +class SamplingStrategy(ABC): + """Base class for sampling strategies.""" + + @abstractmethod + def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: + """Create sampling inputs from the sequences.""" + pass diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py new file mode 100644 index 0000000000..c1e98b20aa --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
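+# Strategy factory for the diffusion-LLM ('dllm') paradigm; the strategies it builds are
+# parameterised by dllm_block_length, the token block size used by dllm decoding.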
+from typing import TYPE_CHECKING + +from lmdeploy.pytorch.config import DLLMConfig, ModelConfig + +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy + from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy + from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy + from lmdeploy.pytorch.strategies.base.engine import EngineStrategy + from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base import StrategyFactoryBase + + +class DLLMStrategyFactory(StrategyFactoryBase): + + def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig): + """config.""" + self.model_config = model_config + self.dllm_config = dllm_config + self.dllm_block_length = self.dllm_config.dllm_block_length + + def build_cudagraph_strategy(self) -> 'CudagraphStrategy': + """Build cudagraph strategy.""" + from .cudagraph import DLLMCudagraphStrategy + return DLLMCudagraphStrategy(block_size=self.dllm_block_length) + + def build_sampling_strategy(self) -> 'SamplingStrategy': + """Build sampling strategy.""" + from .sampling import DLLMSamplingStrategy + pad_token_id = self.model_config.bos_token_id + pad_token_id = 0 if pad_token_id is None else pad_token_id + return DLLMSamplingStrategy(pad_token_id, self.dllm_block_length) + + def build_model_inputs_strategy(self) -> 'ModelInputsStrategy': + """Build model inputs strategy.""" + from .model_inputs import DLLMModelInputsStrategy + return DLLMModelInputsStrategy(block_size=self.dllm_block_length) + + def build_model_agent_strategy(self) -> 'ModelAgentStrategy': + """Build model agent strategy.""" + from .model_agent import DLLMModelAgentStrategy + return DLLMModelAgentStrategy(dllm_config=self.dllm_config, dllm_mask_token=self.model_config.dllm_mask_token) + + def build_engine_strategy(self, cache_config: 'CacheConfig', + scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': + """Build engine strategy.""" + from .engine import DLLMEngineStrategy + return DLLMEngineStrategy(cache_config=cache_config, + scheduler_config=scheduler_config, + dllm_block_length=self.dllm_block_length) diff --git a/lmdeploy/pytorch/strategies/dllm/cudagraph.py b/lmdeploy/pytorch/strategies/dllm/cudagraph.py new file mode 100644 index 0000000000..2e388b22de --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/cudagraph.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from ..base.cudagraph import CudagraphStrategy + + +class DLLMCudagraphStrategy(CudagraphStrategy): + + def __init__(self, block_size: int) -> None: + super().__init__() + self.block_size = block_size + + def get_max_tokens(self, batch_size: int) -> int: + """Get max tokens.""" + return batch_size * self.block_size diff --git a/lmdeploy/pytorch/strategies/dllm/engine.py b/lmdeploy/pytorch/strategies/dllm/engine.py new file mode 100644 index 0000000000..995bb49de3 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/engine.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
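+# Engine scheduling knobs for dllm decoding: prealloc size and loop count are derived from
+# the cache block size and dllm_block_length, capped by prefill_interval // 2.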
+from functools import lru_cache + +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig + +from ..base.engine import EngineStrategy + + +class DLLMEngineStrategy(EngineStrategy): + """DLLM Engine Strategy.""" + + def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, dllm_block_length: int) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + self.dllm_block_length = dllm_block_length + + @lru_cache(maxsize=2) + def get_prealloc_size(self, is_decoding: bool) -> int: + """Get prealloc_size.""" + if not is_decoding: + return 0 + block_size = self.cache_config.block_size + dllm_block_length = self.dllm_block_length + num_blocks = min(self.scheduler_config.prefill_interval // 2, block_size // dllm_block_length) + return num_blocks * dllm_block_length + + @lru_cache(maxsize=2) + def get_num_loops(self, is_decoding: bool) -> int: + """Get num_loops.""" + if not is_decoding: + return 1 + block_size = self.cache_config.block_size + dllm_block_length = self.dllm_block_length + num_blocks = min(self.scheduler_config.prefill_interval // 2, block_size // dllm_block_length) + return num_blocks * dllm_block_length diff --git a/lmdeploy/pytorch/strategies/dllm/model_agent.py b/lmdeploy/pytorch/strategies/dllm/model_agent.py new file mode 100644 index 0000000000..1f98bb5198 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/model_agent.py @@ -0,0 +1,218 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from dataclasses import dataclass +from typing import Any, List, Optional + +import numpy as np +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch import consts +from lmdeploy.pytorch.config import DLLMConfig +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence +from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria +from .unmasking import UnmaskingProcessor + +SeqList = List[SchedulerSequence] + + +@dataclass +class DLLMExtraInputs(ExtraInputs): + """DLLM extra inputs.""" + dllm_mask: torch.Tensor + + +@dataclass +class DLLMExtraOutputs(ExtraOutputs): + """Ar extra outputs.""" + dllm_mask: torch.Tensor + + +def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor, + stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor, + output_start_pos: torch.Tensor, inputs: ModelInputs): + num_tokens = token_ids.size(0) + batch_size = num_appendable_ids.size(0) + block_size = num_tokens // batch_size + + # blocks might contain stop words in prev-round chat + # these stop words should be ignored + kv_seqlens = inputs.history_lengths + inputs.seq_length + ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0) + ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device) + ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten() + token_ids = token_ids.clone() + token_ids[ignore_mask] = -1 + + # find stop words + sw_stopped = (token_ids[:, None] == stop_words).any(1) + sw_stopped = sw_stopped.view(batch_size, block_size) + sw_stop_pos = sw_stopped.int().argmax(1) + + stop_pos = torch.where(stopped, stop_pos, sw_stop_pos) + sw_stopped = sw_stopped.any(dim=1) + sw_stopped = sw_stopped & is_unmasked + stopped = stopped | sw_stopped + + # update num_appendable_ids + one_ids = torch.clamp_max(num_appendable_ids, 0) + 
num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids) + + return stopped, stop_pos, num_appendable_ids + + +@dataclass +class DLLMStoppingCriteria(StoppingCriteria): + num_appendable_ids: torch.Tensor + output_start_pos: torch.Tensor + + @record_function('stopping_criteria') + def step(self, + token_ids: torch.Tensor, + stop_words: torch.Tensor, + inputs: Optional[ModelInputs] = None, + extra_inputs: Optional[DLLMExtraInputs] = None): + """Check whether to stop generation.""" + num_appendable_ids = self.num_appendable_ids + output_start_pos = self.output_start_pos + num_tokens = token_ids.size(0) + batch_size = num_appendable_ids.size(0) + block_size = num_tokens // batch_size + + dllm_mask = extra_inputs.dllm_mask + dllm_mask = dllm_mask.view(batch_size, block_size) + is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1) + + # check stop by num_new_tokens + num_appendable_ids -= is_unmasked * block_size + stopped = num_appendable_ids <= 0 + stop_pos = block_size - 1 + num_appendable_ids + + # check stop words + if stop_words is not None: + stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids, + stop_words, + is_unmasked, + stopped, + stop_pos, + num_appendable_ids, + output_start_pos=output_start_pos, + inputs=inputs) + + new_stopping = DLLMStoppingCriteria(num_appendable_ids=num_appendable_ids, output_start_pos=output_start_pos) + return stopped, stop_pos, new_stopping + + +class DLLMModelAgentStrategy(ModelAgentStrategy): + + def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int): + block_size = dllm_config.dllm_block_length + self.block_size = block_size + self.dllm_mask_token = dllm_mask_token + + self.unmasking_processor = UnmaskingProcessor(dllm_config=dllm_config) + + def _update_dllm(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor): + """Update token_ids and dllm_mask.""" + dllm_mask_token = self.dllm_mask_token + dllm_block_length = self.block_size + + # reshape to (batch, dllm_block_length) + next_token_ids = next_token_ids.view(-1, dllm_block_length).clone() + dllm_mask = dllm_mask.view(-1, dllm_block_length).clone() + + # flags + is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1) + + is_masked = (dllm_mask == consts.DLLM_MASKED) + next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token + dllm_mask[is_cached] = consts.DLLM_MASKED + seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, ))) + + return next_token_ids.flatten(), dllm_mask.flatten(), seqlens + + def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor: + """Slice outputs.""" + block_length = self.block_size + # batch size = 1 + if len(seq_length) == 1: + return inputs[-block_length:] + + if len(seq_length) * block_length == inputs.size(0): + return inputs + last_idx = seq_length.cumsum(0) + block_range = torch.arange(-block_length, 0, device=last_idx.device) + index = (last_idx[:, None] + block_range[None, :]).flatten() + inputs = inputs[index] + return inputs + + def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, seq_length: torch.LongTensor) -> DLLMExtraInputs: + """Slice outputs.""" + dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, seq_length) + return DLLMExtraInputs(dllm_mask=dllm_mask) + + def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor, + dllm_mask: torch.Tensor, **kwargs): + """step.""" + from lmdeploy.pytorch import consts + dllm_block_size = self.block_size + DLLM_UNMASKED = 
consts.DLLM_UNMASKED + is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True) + num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size) + num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos) + sampling_inputs.num_ignore_eos = num_ignore_eos.flatten() + return sampling_inputs + + def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria: + """Create stopping criteria.""" + # num_appendable + num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs] + num_appendable = torch.tensor(num_appendable) + block_size = self.block_size + remain = [seq.num_valid_ids % block_size for seq in seqs] + num_appendable += torch.tensor(remain) + + # output_start_pos + pos = [seq.output_start_pos for seq in seqs] + output_start_pos = torch.tensor(pos) + + return DLLMStoppingCriteria(num_appendable_ids=num_appendable, output_start_pos=output_start_pos) + + def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs: + """Create extra inputs.""" + dllm_masks = [seq.dllm_mask for seq in seqs] + dllm_masks = torch.as_tensor(np.concatenate(dllm_masks)) + return DLLMExtraInputs(dllm_mask=dllm_masks) + + def make_extra_outputs(self, extra_inputs: DLLMExtraInputs) -> DLLMExtraOutputs: + """Create extra outputs.""" + dllm_mask = extra_inputs.dllm_mask + return DLLMExtraOutputs(dllm_mask=dllm_mask) + + def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs', + next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: DLLMExtraInputs, + **kwargs): + """Step next inputs.""" + model_inputs.model_metas = model_metas + dllm_mask = extra_inputs.dllm_mask + + next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length) + model_inputs.step(next_token_ids, step_seqlens) + self._step_sampling_inputs(sampling_inputs, next_token_ids, dllm_mask=dllm_mask) + + extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask) + return model_inputs, extra_inputs + + def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor, + extra_inputs: DLLMExtraInputs): + """Post sampling.""" + dllm_mask = extra_inputs.dllm_mask + input_ids = inputs.input_ids + input_ids = self.slice_outputs(input_ids.flatten(), inputs.seq_length) + + dllm_mask, next_token_ids = self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask) + + extra_inputs.dllm_mask = dllm_mask + return next_token_ids, extra_inputs diff --git a/lmdeploy/pytorch/strategies/dllm/model_inputs.py b/lmdeploy/pytorch/strategies/dllm/model_inputs.py new file mode 100644 index 0000000000..f05ba415f2 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/model_inputs.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
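+# Dummy inputs for dllm warmup: max_q_seqlen is set to the dllm block size, so every dummy
+# sequence contributes one full token block.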
+from lmdeploy.pytorch.model_inputs import ModelInputs + +from ..base.model_inputs import ModelInputsStrategy, make_dummy_inputs + + +class DLLMModelInputsStrategy(ModelInputsStrategy): + + def __init__(self, block_size: int): + self.block_size = block_size + + def make_dummy(self, + batch_size: int, + is_decoding: bool, + device: str = 'cpu', + dummy_block_id: int = 0, + vocab_size: int = 1) -> ModelInputs: + """Create dummy model inputs.""" + return make_dummy_inputs(batch_size, + max_q_seqlen=self.block_size, + is_decoding=is_decoding, + device=device, + dummy_block_id=dummy_block_id, + vocab_size=vocab_size) diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py new file mode 100644 index 0000000000..2ad5d5ecd7 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +from lmdeploy.pytorch.engine.logits_process import SamplingInputs +from lmdeploy.pytorch.messages import SchedulerSequence + +from ..ar.sampling import ARSamplingStrategy + +SeqList = List[SchedulerSequence] + + +class DLLMSamplingStrategy(ARSamplingStrategy): + """Sampling strategy for autoregressive models.""" + + def __init__(self, pad_token_id: int, dllm_block_length: int) -> None: + super().__init__(pad_token_id) + self.dllm_block_length = dllm_block_length + + def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: + """Create sampling inputs from the sequences.""" + out = super().make_sampling_inputs(seqs) + dllm_block_length = self.dllm_block_length + + # repeat tensor + update_attr_names = [ + 'temperature', + 'bad_words', + 'bad_mask', + 'stop_words', + 'stop_mask', + 'repetition_penalty', + 'top_k', + 'top_p', + 'min_p', + 'random_seeds', + 'random_offsets', + 'all_ids', + 'guided_input_ids', + 'num_ignore_eos', + ] + for name in update_attr_names: + attr = getattr(out, name) + if attr is None: + continue + repeats = (dllm_block_length, ) + (1, ) * (attr.dim()) + attr = attr[None].repeat(*repeats).flatten(0, 1) + setattr(out, name, attr) + + if len(out.response_formats) > 0: + new_resp_formats = [] + for resp in out.response_formats: + new_resp_formats += [resp] * dllm_block_length + out.response_formats = tuple(new_resp_formats) + + out.batch_size *= dllm_block_length + + return out diff --git a/lmdeploy/pytorch/engine/unmasking.py b/lmdeploy/pytorch/strategies/dllm/unmasking.py similarity index 100% rename from lmdeploy/pytorch/engine/unmasking.py rename to lmdeploy/pytorch/strategies/dllm/unmasking.py From de49bb5db56f09423642d9070943a5ad42527d8a Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 5 Sep 2025 19:11:54 +0800 Subject: [PATCH 17/29] update seqs --- lmdeploy/pytorch/engine/engine.py | 66 +---- lmdeploy/pytorch/messages.py | 275 ++----------------- lmdeploy/pytorch/strategies/ar/__init__.py | 5 + lmdeploy/pytorch/strategies/ar/sequence.py | 113 ++++++++ lmdeploy/pytorch/strategies/base/__init__.py | 7 +- lmdeploy/pytorch/strategies/base/sequence.py | 30 ++ lmdeploy/pytorch/strategies/dllm/__init__.py | 6 + lmdeploy/pytorch/strategies/dllm/sequence.py | 248 +++++++++++++++++ tests/pytorch/paging/test_block_manager.py | 8 +- tests/pytorch/paging/test_block_trie.py | 4 +- tests/pytorch/paging/test_scheduler.py | 12 +- 11 files changed, 447 insertions(+), 327 deletions(-) create mode 100644 lmdeploy/pytorch/strategies/ar/sequence.py create mode 100644 lmdeploy/pytorch/strategies/base/sequence.py create mode 100644 
lmdeploy/pytorch/strategies/dllm/sequence.py diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 4b5f1ad733..3445b7bfe6 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -139,13 +139,10 @@ def _build_misc_config(engine_config: PytorchEngineConfig): return misc_config -def _build_seq_meta(cache_config: CacheConfig, model_config: ModelConfig, engine_config: PytorchEngineConfig): +def _build_seq_meta(cache_config: CacheConfig, strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta - seq_meta = SequenceMeta(cache_config.block_size, - model_paradigm=model_config.model_paradigm, - dllm_block_length=engine_config.dllm_block_length, - dllm_mask_token=model_config.dllm_mask_token) + seq_meta = SequenceMeta(cache_config.block_size, strategy=strategy) return seq_meta @@ -389,11 +386,12 @@ def __init__(self, self.model_agent_strategy = self.strategy_factory.build_model_agent_strategy() self.engine_strategy = self.strategy_factory.build_engine_strategy(cache_config=cache_config, scheduler_config=scheduler_config) + self.seq_strategy = self.strategy_factory.build_sequence_strategy() self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.seq_meta = _build_seq_meta(cache_config, self.model_config, engine_config) + self.seq_meta = _build_seq_meta(cache_config, strategy=self.seq_strategy) self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta) # engine args @@ -807,60 +805,6 @@ def create_model_inputs(self, messages: SeqList, is_prefill: bool): return model_inputs - def _update_running_default(self, running: SeqList, next_token_ids: torch.Tensor, stopped: List[bool], - model_metas: List[Any], is_decoding: bool): - """Update running default.""" - next_token_ids = next_token_ids.numpy() - update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL - for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): - if msg.status != MessageStatus.LOCKED: - continue - update_token = token - - # fill token - msg.update_token_ids(update_token, model_meta=model_meta, mode=update_mode) - if stop: - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED - - def _update_running_dllm(self, running: SeqList, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, - stopped: List[bool], model_metas: List[Any], is_decoding: bool, stop_pos: torch.Tensor): - batch_size = len(running) - next_token_ids = next_token_ids.view(batch_size, -1).numpy() - dllm_mask = dllm_mask.view(batch_size, -1).numpy() - stop_pos = stop_pos.tolist() - update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL - for idx, token in enumerate(next_token_ids): - msg = running[idx] - stop = stopped[idx] - model_meta = model_metas[idx] - mask = dllm_mask[idx] - if msg.status != MessageStatus.LOCKED: - continue - update_token = token - update_mask = mask - - # fill token - msg.update_token_ids(update_token, dllm_mask=update_mask, model_meta=model_meta, mode=update_mode) - if stop: - msg.set_stop_pos(stop_pos[idx]) - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED - - def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool): - """Update scheduler.""" - next_token_ids = batched_outputs.next_token_ids - stopped = batched_outputs.stopped - stopped = 
stopped.tolist() - model_metas = batched_outputs.model_metas - if model_metas is None: - model_metas = [None] * len(running) - if self.model_config.model_paradigm == 'dllm': - dllm_mask = batched_outputs.extra_outputs.dllm_mask - stop_pos = batched_outputs.stop_pos - return self._update_running_dllm(running, next_token_ids, dllm_mask, stopped, model_metas, is_decoding, - stop_pos) - else: - return self._update_running_default(running, next_token_ids, stopped, model_metas, is_decoding) - def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, model_metas: List[Dict[str, Any]]): """Update scheduler.""" @@ -891,7 +835,7 @@ def _make_infer_outputs( seq_length = [seq.num_token_ids for seq in running] is_run = [seq.status == MessageStatus.LOCKED for seq in running] - self.update_running(running, batched_outputs, is_decoding) + self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding) # generate output outputs: Dict[int, InferOutput] = dict() diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index b6fc2bee25..450166b747 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -1,20 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. import enum -import time from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import numpy as np from torch import Tensor from lmdeploy.messages import EngineEvent, EventType, GenerationConfig, LogitsProcessor -from lmdeploy.pytorch import consts from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest from lmdeploy.pytorch.multimodal.data_type import MultiModalInputs from lmdeploy.utils import get_logger from .block import LogicalTokenBlocks +if TYPE_CHECKING: + from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy + logger = get_logger('lmdeploy') # vlm input type from pipeline @@ -161,9 +162,7 @@ class MessageStatus(enum.Enum): class SequenceMeta: """Meta data shared by all sequence.""" block_size: int - model_paradigm: str = 'ar' - dllm_block_length: int = 1 - dllm_mask_token: int = 151669 + strategy: 'SequenceStrategy' = None class SequenceManager: @@ -225,12 +224,6 @@ def update_sequence_status(self, seq: 'SchedulerSequence', new_status: MessageSt new_status_map[seq_id] = seq -DLLM_MASKED = consts.DLLM_MASKED -DLLM_UNMASKED = consts.DLLM_UNMASKED -DLLM_CACHED = consts.DLLM_CACHED -DLLM_MASK_DTYPE = np.uint8 - - def _to_ndarray(token_ids) -> np.ndarray: """To ndarray.""" if isinstance(token_ids, Tensor): @@ -266,18 +259,13 @@ def add_sequence(self, sampling_param = SamplingParam() seq_id = self.seq_manager._new_seq_id() - seq_cls = SEQ_CLS_MAP[self.seq_meta.model_paradigm] - seq = seq_cls( - seq_id=seq_id, - session=self, - num_new_tokens=0, - sampling_param=sampling_param, - adapter_name=adapter_name, - arrive_time=time.perf_counter(), - migration_request=migration_request, - resp_cache=resp_cache, - preserve_cache=preserve_cache, - ) + seq = self.seq_meta.strategy.make_sequence(seq_id=seq_id, + session=self, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) seq.update_token_ids( token_ids, multimodals=multimodals, @@ -414,12 +402,6 @@ def copy(self): return self.clone() -class HistoryDLLMMask(HistoryTokenIds): - - def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = 
DLLM_MASK_DTYPE): - super().__init__(token_ids=token_ids, dtype=dtype) - - class HistoryMultiModals: def __init__(self, multimodals: MultiModalInputs = None): @@ -651,24 +633,6 @@ def record_event( ) -> None: self.engine_events.append(EngineEvent.new_event(event_type, timestamp)) - def update_token_ids(self, - token_ids: Tensor, - multimodals: MultiModalInputs = None, - embeddings: List[InputEmbeddings] = None, - model_meta: Dict[str, Any] = None, - mode: UpdateTokenMode = UpdateTokenMode.INPUTS, - **kwargs): - """Update token ids, old token ids will be added to history.""" - raise NotImplementedError('NotImplemented') - - def set_step(self, step: int): - """Set step.""" - raise NotImplementedError('NotImplemented') - - -@dataclass -class SchedulerSequenceDefault(SchedulerSequence): - def _update_embeddings(self, embeddings: List[InputEmbeddings]): """Update input embeddings.""" self._num_history_images += self._num_images @@ -679,7 +643,7 @@ def _update_embeddings(self, embeddings: List[InputEmbeddings]): self._num_images = len(new_embeddings) self.history_embeddings.append(new_embeddings) - def _update_multimodals(self, old_num_history_ids: int, multimodals: MultiModalInputs): + def _update_multimodals(self, multimodals: MultiModalInputs): """Update input multimodals.""" self._num_history_cross += self._num_cross if multimodals is None: @@ -689,7 +653,7 @@ def _update_multimodals(self, old_num_history_ids: int, multimodals: MultiModalI self.history_multimodals.add_inputs(multimodals) # for mllama - self._num_cross = self.history_multimodals.get_encoder_len(old_num_history_ids, self._num_history_ids) + self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, self._num_history_ids) def update_token_ids(self, token_ids: Tensor, @@ -699,215 +663,8 @@ def update_token_ids(self, mode: UpdateTokenMode = UpdateTokenMode.INPUTS, **kwargs): """Update token ids, old token ids will be added to history.""" - # update history image nums - self._update_embeddings(embeddings) - - # update multimodals - self._update_multimodals(self._num_history_ids, multimodals) - - self.arrive_time = time.perf_counter() - - token_ids = _to_ndarray(token_ids) - - num_valid = len(token_ids) - - if mode == UpdateTokenMode.INPUTS: - self.output_start_pos = self.num_all_ids + len(token_ids) - self._num_token_ids += num_valid - self.num_new_tokens = 0 - else: - self._num_history_ids += self._num_token_ids - num_token_ids = num_valid - self._num_token_ids = num_token_ids - self.num_new_tokens += num_token_ids - - self.history_cache.append(token_ids) - - if model_meta is not None: - self.model_meta = model_meta + raise NotImplementedError('NotImplemented') def set_step(self, step: int): """Set step.""" - num_all_ids = self.num_all_ids - # update step for vlm - if len(self.history_embeddings) > 0: - new_step, self._num_history_images, self._num_images = \ - self.history_embeddings.get_step(step) - assert 0 <= new_step <= step - step = new_step - self._num_history_ids = step - self._num_token_ids = num_all_ids - step - self.num_ignored_history = min(step, self.num_ignored_history) - - self.model_meta = None - - # cross - if self.history_multimodals is not None: - self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids) - self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids) - - -@dataclass -class SchedulerSequenceDLLM(SchedulerSequenceDefault): - - # For dllm - history_dllm_mask: HistoryDLLMMask = 
field(default_factory=HistoryDLLMMask) - - def __post_init__(self): - """Post init.""" - super().__post_init__() - self._num_valid_ids: int = len(self.history_cache) - - @property - def dllm_mask(self): - start = self.num_history_ids - end = start + self._num_token_ids - return self.history_dllm_mask._token_ids[start:end] - - @property - def num_valid_ids(self): - return self._num_valid_ids - - @property - def generated_ids(self) -> np.ndarray: - end = self.num_valid_ids - start = end - self.num_new_tokens - return self.history_cache._token_ids[start:end] - - @property - def all_dllm_mask(self): - return self.history_dllm_mask._token_ids[:self.num_all_ids] - - @property - def dllm_block_length(self): - return self._seq_meta.dllm_block_length - - def set_stop_pos(self, pos: int): - dllm_block_length = self.dllm_block_length - val = dllm_block_length - pos - 1 - self._num_valid_ids -= val - self.num_new_tokens -= val - - def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray): - """Append tokens.""" - num_tokens = len(token_ids) - dllm_block_length = self.dllm_block_length - dllm_mask_token = self._seq_meta.dllm_mask_token - new_token_ids = [token_ids] - new_dllm_mask = [dllm_mask] - - # add uncached tokens in token_ids - # for example, [cccc cccc uumm], the [uu] in last block is remain valid. - num_remain_valid = self.num_valid_ids - self.num_history_ids - if num_remain_valid != 0: - prev_token_ids = self.valid_ids[-num_remain_valid:] - prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) - new_token_ids = [prev_token_ids] + new_token_ids - new_dllm_mask = [prev_dllm_mask] + new_dllm_mask - self.history_cache.resize(self.num_history_ids) - self.history_dllm_mask.resize(self.num_history_ids) - num_tokens += num_remain_valid - - # pad to align with dllm_block_length - num_pad = (-num_tokens) % dllm_block_length - if num_pad > 0: - pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, )) - pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, )) - new_token_ids += [pad_ids] - new_dllm_mask += [pad_mask] - - token_ids = np.concatenate(new_token_ids) - dllm_mask = np.concatenate(new_dllm_mask) - - assert len(token_ids) % dllm_block_length == 0 - - self.history_cache.append(token_ids) - self.history_dllm_mask.append(dllm_mask) - self.output_start_pos = self._num_valid_ids + len(token_ids) - self._num_valid_ids = self.num_history_ids + num_tokens - self._num_token_ids = len(token_ids) - self.num_new_tokens = 0 - - def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray): - """Update token ids for decode.""" - num_tokens = len(token_ids) - dllm_block_length = self.dllm_block_length - dllm_mask_token = self._seq_meta.dllm_mask_token - assert num_tokens % dllm_block_length == 0 - num_history_ids = self.num_history_ids - - token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token - self.history_cache[num_history_ids:] = token_ids - self.history_dllm_mask[num_history_ids:] = dllm_mask - - # check if all blocks are cached - last_mask = dllm_mask[-dllm_block_length:] - is_unmasked = np.all(last_mask == DLLM_UNMASKED) - is_cached = np.all(last_mask == DLLM_CACHED) - - if is_unmasked: - num_new = dllm_block_length - self._num_valid_ids % dllm_block_length - self._num_valid_ids += num_new - self.num_new_tokens += num_new - - if is_cached: - # add new block - new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, )) - new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, 
shape=(dllm_block_length, )) - self.history_cache.append(new_token_ids) - self.history_dllm_mask.append(new_dllm_mask) - self._num_history_ids += self._num_token_ids - self._num_token_ids = dllm_block_length - - def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray): - """Update token ids for prefill.""" - dllm_block_length = self.dllm_block_length - num_history_ids = self.num_history_ids - - # fill input cache - if self.num_token_ids > dllm_block_length: - end = self.num_token_ids - dllm_block_length - self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED - self._num_history_ids += end - self._num_token_ids -= end - - # decoding update - self._update_token_ids_decode(token_ids, dllm_mask) - - def update_token_ids(self, - token_ids: Tensor, - multimodals: MultiModalInputs = None, - embeddings: List[InputEmbeddings] = None, - model_meta: Dict[str, Any] = None, - dllm_mask: Tensor = None, - mode: UpdateTokenMode = UpdateTokenMode.INPUTS, - **kwargs): - """Update token ids, old token ids will be added to history.""" - # update history image nums - self._update_embeddings(embeddings) - - # update multimodals - self._update_multimodals(self._num_history_ids, multimodals) - - self.arrive_time = time.perf_counter() - - token_ids: np.ndarray = _to_ndarray(token_ids) - if dllm_mask is None: - dllm_mask = np.full_like(token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) - dllm_mask: np.ndarray = _to_ndarray(dllm_mask) - - if mode == UpdateTokenMode.INPUTS: - self._update_token_ids_inputs(token_ids, dllm_mask) - elif mode == UpdateTokenMode.PREFILL: - self._update_token_ids_prefill(token_ids, dllm_mask) - else: - self._update_token_ids_decode(token_ids, dllm_mask) - - if model_meta is not None: - self.model_meta = model_meta - - -SEQ_CLS_MAP = dict( - ar=SchedulerSequenceDefault, - dllm=SchedulerSequenceDLLM, -) + raise NotImplementedError('NotImplemented') diff --git a/lmdeploy/pytorch/strategies/ar/__init__.py b/lmdeploy/pytorch/strategies/ar/__init__.py index be68910e2e..b593107c2e 100644 --- a/lmdeploy/pytorch/strategies/ar/__init__.py +++ b/lmdeploy/pytorch/strategies/ar/__init__.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from lmdeploy.pytorch.config import ModelConfig +from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy if TYPE_CHECKING: from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy @@ -47,3 +48,7 @@ def build_engine_strategy(self, cache_config: 'CacheConfig', """Build engine strategy.""" from .engine import AREngineStrategy return AREngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config) + + def build_sequence_strategy(self) -> SequenceStrategy: + from .sequence import ARSequenceStrategy + return ARSequenceStrategy() diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py new file mode 100644 index 0000000000..91a3335f18 --- /dev/null +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -0,0 +1,113 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
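+# AR sequence handling: SchedulerSequenceDefault carries the token-id
+# bookkeeping that previously lived in lmdeploy/pytorch/messages.py, and
+# ARSequenceStrategy exposes it through make_sequence()/update_running() so
+# the engine no longer branches on the model paradigm.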
+import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from torch import Tensor + +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.engine.model_agent import BatchedOutputs +from lmdeploy.pytorch.messages import (InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, + SchedulerSequence, SchedulerSession, UpdateTokenMode, _to_ndarray) + +from ..base.sequence import SequenceStrategy + +SeqList = List[SchedulerSequence] + + +@dataclass +class SchedulerSequenceDefault(SchedulerSequence): + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(multimodals) + + token_ids = _to_ndarray(token_ids) + + num_valid = len(token_ids) + + if mode == UpdateTokenMode.INPUTS: + self.arrive_time = time.perf_counter() + self.output_start_pos = self.num_all_ids + len(token_ids) + self._num_token_ids += num_valid + self.num_new_tokens = 0 + else: + self._num_history_ids += self._num_token_ids + num_token_ids = num_valid + self._num_token_ids = num_token_ids + self.num_new_tokens += num_token_ids + + self.history_cache.append(token_ids) + + if model_meta is not None: + self.model_meta = model_meta + + def set_step(self, step: int): + """Set step.""" + num_all_ids = self.num_all_ids + # update step for vlm + if len(self.history_embeddings) > 0: + new_step, self._num_history_images, self._num_images = \ + self.history_embeddings.get_step(step) + assert 0 <= new_step <= step + step = new_step + self._num_history_ids = step + self._num_token_ids = num_all_ids - step + self.num_ignored_history = min(step, self.num_ignored_history) + + self.model_meta = None + + # cross + if self.history_multimodals is not None: + self._num_history_cross = self.history_multimodals.get_encoder_len(0, self.num_history_ids) + self._num_cross = self.history_multimodals.get_encoder_len(self._num_history_ids, num_all_ids) + + +class ARSequenceStrategy(SequenceStrategy): + + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequence': + """Make sequence.""" + return SchedulerSequenceDefault(seq_id=seq_id, + session=session, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) + + def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None: + """Update running sequences.""" + next_token_ids = batched_outputs.next_token_ids + stopped = batched_outputs.stopped + stopped = stopped.tolist() + model_metas = batched_outputs.model_metas + if model_metas is None: + model_metas = [None] * len(running) + + next_token_ids = next_token_ids.numpy() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL + for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): + if msg.status != MessageStatus.LOCKED: + continue + + # fill token + msg.update_token_ids(token, model_meta=model_meta, 
mode=update_mode) + if stop: + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED diff --git a/lmdeploy/pytorch/strategies/base/__init__.py b/lmdeploy/pytorch/strategies/base/__init__.py index 030feb3e7e..42519779ad 100644 --- a/lmdeploy/pytorch/strategies/base/__init__.py +++ b/lmdeploy/pytorch/strategies/base/__init__.py @@ -10,10 +10,10 @@ from .model_agent import ModelAgentStrategy from .model_inputs import ModelInputsStrategy from .sampling import SamplingStrategy + from .sequence import SequenceStrategy class StrategyFactoryBase(ABC): - pass @abstractmethod def build_cudagraph_strategy(self) -> 'CudagraphStrategy': @@ -40,3 +40,8 @@ def build_engine_strategy(self, cache_config: 'CacheConfig', scheduler_config: 'SchedulerConfig') -> 'EngineStrategy': """Build engine strategy.""" pass + + @abstractmethod + def build_sequence_strategy(self) -> 'SequenceStrategy': + """Build sequence strategy.""" + pass diff --git a/lmdeploy/pytorch/strategies/base/sequence.py b/lmdeploy/pytorch/strategies/base/sequence.py new file mode 100644 index 0000000000..408a3cc15e --- /dev/null +++ b/lmdeploy/pytorch/strategies/base/sequence.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Optional + +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest + +if TYPE_CHECKING: + from lmdeploy.pytorch.engine.model_agent import BatchedOutputs + from lmdeploy.pytorch.messages import SamplingParam, SchedulerSequence, SchedulerSession + SeqList = List[SchedulerSequence] + + +class SequenceStrategy(ABC): + + @abstractmethod + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequence': + """Make sequence.""" + pass + + @abstractmethod + def update_running(self, running: 'SeqList', batched_outputs: 'BatchedOutputs', is_decoding: bool) -> None: + """Update running sequences.""" + pass diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py index c1e98b20aa..6b90bb6569 100644 --- a/lmdeploy/pytorch/strategies/dllm/__init__.py +++ b/lmdeploy/pytorch/strategies/dllm/__init__.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from lmdeploy.pytorch.config import DLLMConfig, ModelConfig +from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy if TYPE_CHECKING: from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy @@ -51,3 +52,8 @@ def build_engine_strategy(self, cache_config: 'CacheConfig', return DLLMEngineStrategy(cache_config=cache_config, scheduler_config=scheduler_config, dllm_block_length=self.dllm_block_length) + + def build_sequence_strategy(self) -> SequenceStrategy: + from .sequence import DLLMSequenceStrategy + return DLLMSequenceStrategy(block_size=self.dllm_block_length, + dllm_mask_token=self.model_config.dllm_mask_token) diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py new file mode 100644 index 0000000000..ab004a2b63 --- /dev/null +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -0,0 +1,248 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
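+# dllm sequence handling: SchedulerSequenceDLLM extends the AR sequence with a
+# per-token dllm mask (DLLM_MASKED / DLLM_UNMASKED / DLLM_CACHED) and keeps all
+# token updates aligned to dllm_block_length-sized blocks, while
+# DLLMSequenceStrategy builds these sequences and applies sampler outputs
+# block by block.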
+import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import numpy as np +from torch import Tensor + +from lmdeploy.pytorch import consts +from lmdeploy.pytorch.disagg.conn.protocol import MigrationRequest +from lmdeploy.pytorch.engine.model_agent import BatchedOutputs +from lmdeploy.pytorch.messages import (HistoryTokenIds, InputEmbeddings, MessageStatus, MultiModalInputs, SamplingParam, + SchedulerSession, UpdateTokenMode, _to_ndarray) + +from ..ar.sequence import SchedulerSequenceDefault +from ..base.sequence import SequenceStrategy + +SeqList = List['SchedulerSequenceDLLM'] + +DLLM_MASKED = consts.DLLM_MASKED +DLLM_UNMASKED = consts.DLLM_UNMASKED +DLLM_CACHED = consts.DLLM_CACHED +DLLM_MASK_DTYPE = np.uint8 + + +class HistoryDLLMMask(HistoryTokenIds): + + def __init__(self, token_ids: np.ndarray = None, dtype: np.dtype = DLLM_MASK_DTYPE): + super().__init__(token_ids=token_ids, dtype=dtype) + + +@dataclass +class SchedulerSequenceDLLM(SchedulerSequenceDefault): + + # For dllm + history_dllm_mask: HistoryDLLMMask = field(default_factory=HistoryDLLMMask) + + def __post_init__(self): + """Post init.""" + super().__post_init__() + self._num_valid_ids: int = len(self.history_cache) + self._strategy: DLLMSequenceStrategy = self._seq_meta.strategy + + @property + def dllm_mask(self): + start = self.num_history_ids + end = start + self._num_token_ids + return self.history_dllm_mask._token_ids[start:end] + + @property + def num_valid_ids(self): + return self._num_valid_ids + + @property + def generated_ids(self) -> np.ndarray: + end = self.num_valid_ids + start = end - self.num_new_tokens + return self.history_cache._token_ids[start:end] + + @property + def all_dllm_mask(self): + return self.history_dllm_mask._token_ids[:self.num_all_ids] + + @property + def dllm_block_length(self): + return self._strategy.block_size + + @property + def dllm_mask_token(self): + return self._strategy.dllm_mask_token + + def set_stop_pos(self, pos: int): + dllm_block_length = self.dllm_block_length + val = dllm_block_length - pos - 1 + self._num_valid_ids -= val + self.num_new_tokens -= val + + def _update_token_ids_inputs(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Append tokens.""" + num_tokens = len(token_ids) + dllm_block_length = self.dllm_block_length + dllm_mask_token = self.dllm_mask_token + new_token_ids = [token_ids] + new_dllm_mask = [dllm_mask] + + # add uncached tokens in token_ids + # for example, [cccc cccc uumm], the [uu] in last block is remain valid. 
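+        # (c: DLLM_CACHED, u: DLLM_UNMASKED, m: DLLM_MASKED; each block holds
+        # dllm_block_length tokens, 4 in this example)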
+ num_remain_valid = self.num_valid_ids - self.num_history_ids + if num_remain_valid != 0: + prev_token_ids = self.valid_ids[-num_remain_valid:] + prev_dllm_mask = np.full_like(prev_token_ids, DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) + new_token_ids = [prev_token_ids] + new_token_ids + new_dllm_mask = [prev_dllm_mask] + new_dllm_mask + self.history_cache.resize(self.num_history_ids) + self.history_dllm_mask.resize(self.num_history_ids) + num_tokens += num_remain_valid + + # pad to align with dllm_block_length + num_pad = (-num_tokens) % dllm_block_length + if num_pad > 0: + pad_ids = np.full_like(token_ids, dllm_mask_token, shape=(num_pad, )) + pad_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(num_pad, )) + new_token_ids += [pad_ids] + new_dllm_mask += [pad_mask] + + token_ids = np.concatenate(new_token_ids) + dllm_mask = np.concatenate(new_dllm_mask) + + assert len(token_ids) % dllm_block_length == 0 + + self.history_cache.append(token_ids) + self.history_dllm_mask.append(dllm_mask) + self.output_start_pos = self._num_valid_ids + len(token_ids) + self._num_valid_ids = self.num_history_ids + num_tokens + self._num_token_ids = len(token_ids) + self.num_new_tokens = 0 + + def _update_token_ids_decode(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Update token ids for decode.""" + num_tokens = len(token_ids) + dllm_block_length = self.dllm_block_length + dllm_mask_token = self.dllm_mask_token + assert num_tokens % dllm_block_length == 0 + num_history_ids = self.num_history_ids + + token_ids[dllm_mask == DLLM_MASKED] = dllm_mask_token + self.history_cache[num_history_ids:] = token_ids + self.history_dllm_mask[num_history_ids:] = dllm_mask + + # check if all blocks are cached + last_mask = dllm_mask[-dllm_block_length:] + is_unmasked = np.all(last_mask == DLLM_UNMASKED) + is_cached = np.all(last_mask == DLLM_CACHED) + + if is_unmasked: + num_new = dllm_block_length - self._num_valid_ids % dllm_block_length + self._num_valid_ids += num_new + self.num_new_tokens += num_new + + if is_cached: + # add new block + new_token_ids = np.full_like(token_ids, dllm_mask_token, shape=(dllm_block_length, )) + new_dllm_mask = np.full_like(dllm_mask, DLLM_MASKED, shape=(dllm_block_length, )) + self.history_cache.append(new_token_ids) + self.history_dllm_mask.append(new_dllm_mask) + self._num_history_ids += self._num_token_ids + self._num_token_ids = dllm_block_length + + def _update_token_ids_prefill(self, token_ids: np.ndarray, dllm_mask: np.ndarray): + """Update token ids for prefill.""" + dllm_block_length = self.dllm_block_length + num_history_ids = self.num_history_ids + + # fill input cache + if self.num_token_ids > dllm_block_length: + end = self.num_token_ids - dllm_block_length + self.history_dllm_mask[num_history_ids:end] = DLLM_CACHED + self._num_history_ids += end + self._num_token_ids -= end + + # decoding update + self._update_token_ids_decode(token_ids, dllm_mask) + + def update_token_ids(self, + token_ids: Tensor, + multimodals: MultiModalInputs = None, + embeddings: List[InputEmbeddings] = None, + model_meta: Dict[str, Any] = None, + dllm_mask: Tensor = None, + mode: UpdateTokenMode = UpdateTokenMode.INPUTS, + **kwargs): + """Update token ids, old token ids will be added to history.""" + # update history image nums + self._update_embeddings(embeddings) + + # update multimodals + self._update_multimodals(multimodals) + + self.arrive_time = time.perf_counter() + + token_ids: np.ndarray = _to_ndarray(token_ids) + if dllm_mask is None: + dllm_mask = np.full_like(token_ids, 
DLLM_UNMASKED, dtype=DLLM_MASK_DTYPE) + dllm_mask: np.ndarray = _to_ndarray(dllm_mask) + + if mode == UpdateTokenMode.INPUTS: + self._update_token_ids_inputs(token_ids, dllm_mask) + elif mode == UpdateTokenMode.PREFILL: + self._update_token_ids_prefill(token_ids, dllm_mask) + else: + self._update_token_ids_decode(token_ids, dllm_mask) + + if model_meta is not None: + self.model_meta = model_meta + + +class DLLMSequenceStrategy(SequenceStrategy): + + def __init__(self, block_size: int, dllm_mask_token: int) -> None: + self.block_size = block_size + self.dllm_mask_token = dllm_mask_token + + def make_sequence(self, + seq_id: int, + session: 'SchedulerSession', + sampling_param: 'SamplingParam' = None, + adapter_name: str = None, + migration_request: Optional[MigrationRequest] = None, + resp_cache: bool = False, + preserve_cache: bool = False) -> 'SchedulerSequenceDLLM': + """Make sequence.""" + return SchedulerSequenceDLLM(seq_id=seq_id, + session=session, + sampling_param=sampling_param, + adapter_name=adapter_name, + migration_request=migration_request, + resp_cache=resp_cache, + preserve_cache=preserve_cache) + + def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_decoding: bool) -> None: + """Update running sequences.""" + next_token_ids = batched_outputs.next_token_ids + stopped = batched_outputs.stopped + stopped = stopped.tolist() + model_metas = batched_outputs.model_metas + if model_metas is None: + model_metas = [None] * len(running) + dllm_mask = batched_outputs.extra_outputs.dllm_mask + stop_pos = batched_outputs.stop_pos + + batch_size = len(running) + next_token_ids = next_token_ids.view(batch_size, -1).numpy() + dllm_mask = dllm_mask.view(batch_size, -1).numpy() + stop_pos = stop_pos.tolist() + update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL + for idx, token in enumerate(next_token_ids): + msg = running[idx] + stop = stopped[idx] + model_meta = model_metas[idx] + mask = dllm_mask[idx] + if msg.status != MessageStatus.LOCKED: + continue + + # fill token + msg.update_token_ids(token, dllm_mask=mask, model_meta=model_meta, mode=update_mode) + if stop: + msg.set_stop_pos(stop_pos[idx]) + msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index 6f7583c9a9..f74b6548cf 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -91,7 +91,9 @@ def block_mgr(self, num_cpu_blocks, num_gpu_blocks): @pytest.fixture def seq_manager(self, block_size): - seq_meta = SequenceMeta(block_size) + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + seq_meta = SequenceMeta(block_size, strategy=strategy) yield SequenceManager(seq_meta) def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): @@ -226,7 +228,9 @@ def num_gpu_blocks(self): @pytest.fixture def seq_manager(self, block_size): - seq_meta = SequenceMeta(block_size) + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + seq_meta = SequenceMeta(block_size, strategy=strategy) yield SequenceManager(seq_meta) @pytest.fixture diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index a9a3f571e9..7d20c96dab 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -39,7 +39,9 @@ def block_trie(self, 
cache_config, block_mgr): @pytest.fixture def seq_manager(self, block_size): - seq_meta = SequenceMeta(block_size) + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + seq_meta = SequenceMeta(block_size, strategy=strategy) yield SequenceManager(seq_meta) def test_allocate(self, block_trie, block_mgr, seq_manager): diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index f14ab8249e..a0acf5f054 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -2,7 +2,7 @@ import torch from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig -from lmdeploy.pytorch.messages import MessageStatus +from lmdeploy.pytorch.messages import MessageStatus, SequenceMeta from lmdeploy.pytorch.paging.scheduler import Scheduler @@ -32,8 +32,14 @@ def scheduler_config(self): yield SchedulerConfig(max_batches=4, max_session_len=128, max_request_output_len=64, eviction_type='recompute') @pytest.fixture - def scheduler(self, cache_config, scheduler_config): - yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config) + def seq_meta(self, block_size): + from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy + strategy = ARSequenceStrategy() + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): block_manager = scheduler.block_manager From 3890cfe0918ba6766047c740f784ad1d8c632c8f Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 8 Sep 2025 13:56:41 +0800 Subject: [PATCH 18/29] add moe support --- README.md | 1 + README_ja.md | 1 + README_zh-CN.md | 1 + docs/en/supported_models/supported_models.md | 1 + .../supported_models/supported_models.md | 1 + lmdeploy/pytorch/configurations/sdar.py | 2 +- lmdeploy/pytorch/models/module_map.py | 1 + lmdeploy/pytorch/models/sdar_moe.py | 501 ++++++++++++++++++ 8 files changed, 508 insertions(+), 1 deletion(-) create mode 100644 lmdeploy/pytorch/models/sdar_moe.py diff --git a/README.md b/README.md index d504176dd0..87c6715546 100644 --- a/README.md +++ b/README.md @@ -149,6 +149,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
+  • SDAR (1.7B-30B)
  • diff --git a/README_ja.md b/README_ja.md index 009d6749ad..4fe5f679c4 100644 --- a/README_ja.md +++ b/README_ja.md @@ -141,6 +141,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
+  • SDAR (1.7B-30B)
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index 788995f1e4..6a3d35f508 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -150,6 +150,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Phi-3.5-MoE (16x3.8B)
  • Phi-4-mini (3.8B)
  • MiniCPM3 (4B)
+  • SDAR (1.7B-30B)
  • diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 2766fdb83d..af7573b1b5 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -117,6 +117,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | +| SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - | ```{note} * [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 4d9998f2df..5f209cea4f 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -117,6 +117,7 @@ | Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - | | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | +| SDAR | 1.7B-30B | LLM | Yes | Yes | No | - | - | ```{note} * [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16 diff --git a/lmdeploy/pytorch/configurations/sdar.py b/lmdeploy/pytorch/configurations/sdar.py index fdf6353760..edf9cd3cad 100644 --- a/lmdeploy/pytorch/configurations/sdar.py +++ b/lmdeploy/pytorch/configurations/sdar.py @@ -7,7 +7,7 @@ class SDARModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): """config.""" - return hf_config.model_type == 'sdar' + return hf_config.model_type in ['sdar', 'sdar_moe'] @classmethod def build(cls, hf_config, model_path: str = None, **kwargs): diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 0e36035bc1..08ac1cd21d 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -223,6 +223,7 @@ # SDAR MODULE_MAP.update({ 'SDARForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar.SDARForCausalLM', + 'SDARMoeForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.sdar_moe.SDARMoeForCausalLM', }) CUSTOM_MODULE_MAP = dict() diff --git a/lmdeploy/pytorch/models/sdar_moe.py b/lmdeploy/pytorch/models/sdar_moe.py new file mode 100644 index 0000000000..d7c569402a --- /dev/null +++ b/lmdeploy/pytorch/models/sdar_moe.py @@ -0,0 +1,501 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
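+# SDAR MoE model: SDARMoeAttention applies q/k RMSNorm and block-sparse
+# attention sized by dllm_block_length; each decoder layer routes hidden
+# states through either a dense SDARMoeMLP or a fused mixture-of-experts
+# block, selected by mlp_only_layers and decoder_sparse_step.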
+ +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.distributed import get_tp_world_rank +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config +from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj, + build_rowwise_linear) +from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .utils.cudagraph import CudaGraphMixin + + +class SDARMoeAttention(nn.Module): + """attention.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + num_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + hidden_size = config.hidden_size + head_dim = getattr(config, 'head_dim', hidden_size // num_heads) + num_replicate_kv_heads = getattr(config, 'num_replicate_key_value_heads', 1) + # packed qkv + # Qwen3 uses 'config.attention_bias = False' for q/k/o projections + self.qkv_proj = build_qkv_proj(hidden_size, + num_q_heads=num_heads, + num_kv_heads=num_key_value_heads, + head_size=head_dim, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + num_replicate_kv_heads=num_replicate_kv_heads) + + # rotary embedding + self.apply_rotary_pos_emb = ApplyRotaryEmb() + dllm_block_length = config.dllm_block_length + + # attention + self.attn_fwd = Attention( + num_heads, + head_dim, + num_kv_heads=num_key_value_heads, + v_head_size=head_dim, + sliding_window=config.sliding_window, + block_sparse_size=dllm_block_length, + ) + + # o_proj + self.o_proj = build_o_proj(num_heads * head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + # q, k norm + self.q_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + self.k_norm = RMSNorm(head_dim, config.rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attn_metadata: Any = None, + ): + """Rewrite of LlamaAttention.forward.""" + # qkv proj + qkv_states = self.qkv_proj(hidden_states) + # (-1, heads, head_dim) + qkv_states = qkv_states.flatten(0, -2) + query_states, key_states, value_states = self.qkv_proj.split_qkv(qkv_states) + + # apply q, k norm + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + # apply rotary embedding + cos, sin = rotary_pos_emb + query_states, key_states = self.apply_rotary_pos_emb( + query_states, + key_states, + cos, + sin, + ) + # attention + attn_output = self.attn_fwd( + query_states, + key_states, + value_states, + past_key_value[0], + past_key_value[1], + attn_metadata, + k_scales_zeros=None if len(past_key_value) == 2 else past_key_value[2], + v_scales_zeros=None if len(past_key_value) == 2 else past_key_value[3], + inplace=True, + ) + attn_output = attn_output.reshape(*hidden_states.shape[:-1], -1) + + # o proj + attn_output = self.o_proj(attn_output) + return attn_output + + +class SDARMoeMLP(nn.Module): + 
"""mlp.""" + + def __init__(self, + config: PretrainedConfig, + intermediate_size: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.gate_up_proj = build_gateup_linear( + config.hidden_size, + [intermediate_size, intermediate_size], + bias=False, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + self.act_fn = SiluAndMul(inplace=True) + + # down + self.down_proj = build_down_linear(intermediate_size, + config.hidden_size, + bias=False, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + gate_up = self.gate_up_proj(x) + act = self.act_fn(gate_up) + return self.down_proj(act) + + +class SDARMoeSparseMoeBlock(nn.Module): + """SDARMoeSparseMoeBlock.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + quantization_config = getattr(config, 'quantization_config', None) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.moe_intermediate_size + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.renormalize = config.norm_topk_prob + + self.gate = build_rowwise_linear( + self.hidden_dim, + self.num_experts, + bias=False, + dtype=dtype, + device=device, + is_tp=False, + ) + + self.softmax_topk = SoftmaxTopK(self.top_k) + + world_size, _ = get_tp_world_rank() + _all_reduce = world_size > 1 + self.experts = build_fused_moe( + self.hidden_dim, + self.ffn_dim, + self.num_experts, + top_k=self.top_k, + renormalize=self.renormalize, + dtype=dtype, + device=device, + quant_config=quantization_config, + all_reduce=_all_reduce, + layer_idx=layer_idx, + ) + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + topk_weights, topk_ids = self.softmax_topk(router_logits) + out_states = self.experts( + hidden_states, + topk_weights, + topk_ids, + ) + + out_states = out_states.reshape(batch_size, sequence_length, -1) + return out_states + + +class SDARMoeDecoderLayer(nn.Module): + """Decode layer.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + quantization_config = getattr(config, 'quantization_config', None) + + # build attention layer + self.self_attn = SDARMoeAttention(config, dtype=dtype, device=device) + + # build MLP + if (layer_idx not in config.mlp_only_layers) and (config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = SDARMoeSparseMoeBlock(config, layer_idx, dtype=dtype, device=device) + else: + self.mlp = SDARMoeMLP(config, intermediate_size=config.intermediate_size, dtype=dtype, device=device) + + # build input layer norm + self.input_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + # build attention layer norm + self.post_attention_layernorm = RMSNorm(config.hidden_size, + config.rms_norm_eps, + quant_config=quantization_config, + dtype=dtype, + device=device) + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: 
Tuple[torch.FloatTensor, torch.FloatTensor], + past_key_value: Optional[List[torch.FloatTensor]], + residual: Optional[torch.Tensor] = None, + attn_metadata: Any = None, + ): + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + outputs = (hidden_states, residual) + return outputs + + +class SDARMoeModel(nn.Module): + """SDAR model.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + SDARMoeDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + self.rotary_emb = build_rotary_embedding_from_config(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ): + """Rewrite of forward.""" + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + cos, sin = self.rotary_emb(hidden_states, position_ids) + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class SDARMoeForCausalLM(nn.Module, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.dllm_block_length + # build model + self.model = SDARMoeModel(config, dtype=dtype, device=device) + # build lm_head + self.lm_head = build_rowwise_linear(config.hidden_size, + config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + **kwargs, + ): + """Model 
forward, return logits.""" + hidden_states = self.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def update_weights(self): + """Update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """Load weight experts.""" + # load fused weights + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert map + num_experts = self.config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + + if '.experts' in name: + self._load_weight_experts(name, loaded_weight, params_dict, expert_params_mapping=expert_params_mapping) + else: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, 
loaded_weight, shard_id=shard_id) + break + else: + param = params_dict[name] + load_weight(param, loaded_weight) From c1e4cde70ad096444becce516817b5bb14247aa1 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 8 Sep 2025 14:07:26 +0800 Subject: [PATCH 19/29] bind block length --- lmdeploy/cli/utils.py | 2 +- lmdeploy/messages.py | 2 +- lmdeploy/pytorch/config.py | 6 ++---- lmdeploy/pytorch/configurations/sdar.py | 1 + lmdeploy/pytorch/strategies/dllm/__init__.py | 20 +++++++++++++++++++- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index bdfe2b5217..66608c4dc4 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -613,7 +613,7 @@ def logprobs_mode(parser): @staticmethod def dllm_block_length(parser): """dllm_block_length for dllm.""" - return parser.add_argument('--dllm-block-length', type=int, default=1, help='Block length for dllm') + return parser.add_argument('--dllm-block-length', type=int, default=None, help='Block length for dllm') @staticmethod def dllm_unmasking_strategy(parser): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index 51a3748720..69ac361abc 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -377,7 +377,7 @@ class PytorchEngineConfig: logprobs_mode: str = None # dllm - dllm_block_length: int = 1 + dllm_block_length: int = None dllm_unmasking_strategy: str = 'low_confidence_dynamic' dllm_denoising_steps: int = None dllm_confidence_threshold: float = 0.85 diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index c75358d9cc..8e4a5e177a 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -202,6 +202,7 @@ class ModelConfig: use_flash_mla: bool = False model_paradigm: str = 'ar' dllm_mask_token: int = 0 + dllm_block_length: int = None def get_head_size(self): """Get head size.""" @@ -309,12 +310,9 @@ class MiscConfig: @classmethod def from_engine_config(cls, engine_config: PytorchEngineConfig): """From engine config.""" - denoising_steps = engine_config.dllm_denoising_steps - if denoising_steps is None: - denoising_steps = engine_config.dllm_block_length // 2 dllm_config = DLLMConfig(dllm_block_length=engine_config.dllm_block_length, unmasking_strategy=engine_config.dllm_unmasking_strategy, - denoising_steps=denoising_steps, + denoising_steps=engine_config.dllm_denoising_steps, confidence_threshold=engine_config.dllm_confidence_threshold) misc_config = cls(custom_module_map=engine_config.custom_module_map, empty_init=engine_config.empty_init, diff --git a/lmdeploy/pytorch/configurations/sdar.py b/lmdeploy/pytorch/configurations/sdar.py index edf9cd3cad..dcf323ce73 100644 --- a/lmdeploy/pytorch/configurations/sdar.py +++ b/lmdeploy/pytorch/configurations/sdar.py @@ -15,4 +15,5 @@ def build(cls, hf_config, model_path: str = None, **kwargs): cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) cfg.dllm_mask_token = 151669 cfg.model_paradigm = 'dllm' + cfg.dllm_block_length = 4 return cfg diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py index 6b90bb6569..f704e8e036 100644 --- a/lmdeploy/pytorch/strategies/dllm/__init__.py +++ b/lmdeploy/pytorch/strategies/dllm/__init__.py @@ -21,7 +21,25 @@ def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig): """config.""" self.model_config = model_config self.dllm_config = dllm_config - self.dllm_block_length = self.dllm_config.dllm_block_length + + # update dllm_block_length + self.dllm_block_length = 
self._update_dllm_block_length() + + def _update_dllm_block_length(self): + """Update dllm_block_length.""" + if self.dllm_config.dllm_block_length is None: + dllm_block_length = self.model_config.dllm_block_length + else: + dllm_block_length = self.dllm_config.dllm_block_length + + assert dllm_block_length is not None, 'dllm_block_length should be set in model_config or dllm_config' + + self.dllm_config.dllm_block_length = dllm_block_length + self.model_config.dllm_block_length = dllm_block_length + + if self.dllm_config.denoising_steps is None: + self.dllm_config.denoising_steps = dllm_block_length + return dllm_block_length def build_cudagraph_strategy(self) -> 'CudagraphStrategy': """Build cudagraph strategy.""" From 26f4c2d55903fe80cb7ee1cefa35798ad84a475f Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 12 Sep 2025 20:54:39 +0800 Subject: [PATCH 20/29] fix num loops --- lmdeploy/pytorch/strategies/dllm/engine.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/strategies/dllm/engine.py b/lmdeploy/pytorch/strategies/dllm/engine.py index 995bb49de3..93e91d6d22 100644 --- a/lmdeploy/pytorch/strategies/dllm/engine.py +++ b/lmdeploy/pytorch/strategies/dllm/engine.py @@ -31,5 +31,6 @@ def get_num_loops(self, is_decoding: bool) -> int: return 1 block_size = self.cache_config.block_size dllm_block_length = self.dllm_block_length - num_blocks = min(self.scheduler_config.prefill_interval // 2, block_size // dllm_block_length) - return num_blocks * dllm_block_length + max_num_loops = block_size // dllm_block_length * 2 + num_loops = min(self.scheduler_config.prefill_interval, max_num_loops) + return num_loops From 11674bf8ce56b6915ef35cd8991d04b6d8c83697 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 15 Sep 2025 11:35:29 +0800 Subject: [PATCH 21/29] enum unmasking type --- lmdeploy/pytorch/config.py | 26 +++++++++++++++++-- lmdeploy/pytorch/strategies/dllm/unmasking.py | 8 +++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 8e4a5e177a..263978aa56 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import enum from dataclasses import dataclass from typing import Any, Dict, List, Literal @@ -288,10 +289,30 @@ def from_hf_config(cls, return model_config +class UnmaskingStrategy(enum.Enum): + """Unmasking Strategy.""" + SEQUENTIAL = enum.auto() + LOW_CONFIDENCE_DYNAMIC = enum.auto() + LOW_CONFIDENCE_STATIC = enum.auto() + + @classmethod + def from_str(cls, strategy: str): + """From string.""" + strategy = strategy.lower() + if strategy == 'sequential': + return cls.SEQUENTIAL + elif strategy == 'low_confidence_dynamic': + return cls.LOW_CONFIDENCE_DYNAMIC + elif strategy == 'low_confidence_static': + return cls.LOW_CONFIDENCE_STATIC + else: + raise ValueError(f'Unknown unmasking strategy: {strategy}') + + @dataclass class DLLMConfig: dllm_block_length: int = 1 - unmasking_strategy: str = 'low_confidence_dynamic' + unmasking_strategy: UnmaskingStrategy = UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC denoising_steps: int = None confidence_threshold: float = 0.85 @@ -310,8 +331,9 @@ class MiscConfig: @classmethod def from_engine_config(cls, engine_config: PytorchEngineConfig): """From engine config.""" + dllm_unmasking_strategy = UnmaskingStrategy.from_str(engine_config.dllm_unmasking_strategy) dllm_config = DLLMConfig(dllm_block_length=engine_config.dllm_block_length, - unmasking_strategy=engine_config.dllm_unmasking_strategy, + unmasking_strategy=dllm_unmasking_strategy, denoising_steps=engine_config.dllm_denoising_steps, confidence_threshold=engine_config.dllm_confidence_threshold) misc_config = cls(custom_module_map=engine_config.custom_module_map, diff --git a/lmdeploy/pytorch/strategies/dllm/unmasking.py b/lmdeploy/pytorch/strategies/dllm/unmasking.py index 6b77102e8f..dba0a10a69 100644 --- a/lmdeploy/pytorch/strategies/dllm/unmasking.py +++ b/lmdeploy/pytorch/strategies/dllm/unmasking.py @@ -3,7 +3,7 @@ from torch.profiler import record_function from lmdeploy.pytorch import consts -from lmdeploy.pytorch.config import DLLMConfig +from lmdeploy.pytorch.config import DLLMConfig, UnmaskingStrategy DLLM_MASKED = consts.DLLM_MASKED DLLM_UNMASKED = consts.DLLM_UNMASKED @@ -105,11 +105,11 @@ def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: tor dllm_mask = dllm_mask.flatten() token_ids = torch.where(dllm_mask != DLLM_MASKED, input_ids, token_ids) - if strategy == 'low_confidence_static': + if strategy == UnmaskingStrategy.LOW_CONFIDENCE_STATIC: dllm_mask = self.low_confidence_static(logits, token_ids, dllm_mask) - elif strategy == 'low_confidence_dynamic': + elif strategy == UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC: dllm_mask = self.low_confidence_dynamic(logits, token_ids, dllm_mask) - elif strategy == 'sequential': + elif strategy == UnmaskingStrategy.SEQUENTIAL: dllm_mask = self.sequential(dllm_mask) else: raise RuntimeError(f'strategy {strategy} not supported.') From 8fce74a0a643828d374e0d3c4c2ce92976df70c2 Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 15 Sep 2025 11:41:03 +0800 Subject: [PATCH 22/29] typo fixing --- lmdeploy/pytorch/engine/model_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 7b7d6a8905..daddcf4e3c 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -699,7 +699,7 @@ async def __prepare_dp(): next_token_ids, logprobs = await self.async_sampling_logits(last_logits, sampling_inputs, inputs) with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): - 
logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') + logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') # post sampling next_token_ids, extra_inputs = self.agent_strategy.post_sampling( @@ -719,7 +719,7 @@ async def __prepare_dp(): # broadcast next token for TP > 1 with self._broadcast_next_token(next_token_ids, dist_ctx, enable=need_broadcast_next): - logger.debug(f' rank[{rank}]: synchornize token ids [{idx}]') + logger.debug(f' rank[{rank}]: synchronize token ids [{idx}]') # post sampling next_token_ids, extra_inputs = self.agent_strategy.post_sampling(inputs, last_logits, next_token_ids, From 94c3013525b83b2d29d76df00907793311b7d56b Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 15 Sep 2025 15:18:00 +0800 Subject: [PATCH 23/29] warning --- lmdeploy/pytorch/configurations/sdar.py | 1 - lmdeploy/pytorch/strategies/dllm/__init__.py | 7 +++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/configurations/sdar.py b/lmdeploy/pytorch/configurations/sdar.py index dcf323ce73..edf9cd3cad 100644 --- a/lmdeploy/pytorch/configurations/sdar.py +++ b/lmdeploy/pytorch/configurations/sdar.py @@ -15,5 +15,4 @@ def build(cls, hf_config, model_path: str = None, **kwargs): cfg = DefaultModelConfigBuilder.build(hf_config, model_path, **kwargs) cfg.dllm_mask_token = 151669 cfg.model_paradigm = 'dllm' - cfg.dllm_block_length = 4 return cfg diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py index f704e8e036..25736033c1 100644 --- a/lmdeploy/pytorch/strategies/dllm/__init__.py +++ b/lmdeploy/pytorch/strategies/dllm/__init__.py @@ -3,6 +3,7 @@ from lmdeploy.pytorch.config import DLLMConfig, ModelConfig from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy +from lmdeploy.utils import get_logger if TYPE_CHECKING: from lmdeploy.pytorch.strategies.base.cudagraph import CudagraphStrategy @@ -14,6 +15,8 @@ from ..base import StrategyFactoryBase +logger = get_logger('lmdeploy') + class DLLMStrategyFactory(StrategyFactoryBase): @@ -29,6 +32,10 @@ def _update_dllm_block_length(self): """Update dllm_block_length.""" if self.dllm_config.dllm_block_length is None: dllm_block_length = self.model_config.dllm_block_length + if dllm_block_length is None: + dllm_block_length = 4 + logger.warning('Model does not provide dllm_block_length. ' + f'Set dllm_block_length={dllm_block_length} as default.') else: dllm_block_length = self.dllm_config.dllm_block_length From c74b53571265c1e9ee4daf01717a6ff957295055 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 16 Sep 2025 12:22:58 +0800 Subject: [PATCH 24/29] fix metric --- lmdeploy/metrics/stats.py | 5 ++++- lmdeploy/pytorch/engine/engine.py | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lmdeploy/metrics/stats.py b/lmdeploy/metrics/stats.py index b21eeaca90..6654b3d9d3 100644 --- a/lmdeploy/metrics/stats.py +++ b/lmdeploy/metrics/stats.py @@ -198,7 +198,10 @@ def update_from_output(self, outputs: EngineOutput, req_state: RequestState): outputs (EngineOutput): The output from the engine containing information about the current iteration. req_state (RequestState): The state of the request, including timestamps and token counts. 
""" - self.new_generation_tokens = outputs.num_token - req_state.generation_tokens + new_generation_tokens = outputs.num_token - req_state.generation_tokens + if new_generation_tokens == 0: + return + self.new_generation_tokens = new_generation_tokens if req_state.first_token_time == 0: # It means the first token is generated in this iteration req_state.first_token_time = outputs.req_metrics.token_timestamp diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 3445b7bfe6..051951d501 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -846,6 +846,10 @@ def _make_infer_outputs( finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED if not finish and len(token_ids) == 0: continue + resp_data = msg.resp.data + if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids): + # no new tokens + continue session_id = msg.session_id if msg.resp_cache: cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist() From bbd14891a6cbdef188e428f3ef5db4b581157e1a Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 16 Sep 2025 12:55:24 +0800 Subject: [PATCH 25/29] limit batch size --- lmdeploy/pytorch/engine/engine.py | 9 +++++++++ lmdeploy/pytorch/strategies/dllm/engine.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 051951d501..eab98848b3 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -81,6 +81,15 @@ def _update_engine_config(engine_config: PytorchEngineConfig): if engine_config.max_batch_size is None: engine_config.max_batch_size = get_max_batch_size(engine_config.device_type) + if engine_config.dllm_block_length is not None: + max_prefill_token_num = engine_config.max_prefill_token_num + max_batch_size = engine_config.max_batch_size + if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num: + engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length + logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} ' + f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' + f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') + if engine_config.dp != 1: if engine_config.tp == 1 and engine_config.ep == 1: engine_config.dp = 1 diff --git a/lmdeploy/pytorch/strategies/dllm/engine.py b/lmdeploy/pytorch/strategies/dllm/engine.py index 93e91d6d22..32244a7abc 100644 --- a/lmdeploy/pytorch/strategies/dllm/engine.py +++ b/lmdeploy/pytorch/strategies/dllm/engine.py @@ -2,9 +2,12 @@ from functools import lru_cache from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig +from lmdeploy.utils import get_logger from ..base.engine import EngineStrategy +logger = get_logger('lmdeploy') + class DLLMEngineStrategy(EngineStrategy): """DLLM Engine Strategy.""" @@ -14,6 +17,17 @@ def __init__(self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, self.cache_config = cache_config self.dllm_block_length = dllm_block_length + self._check() + + def _check(self): + """check.""" + max_prefill_token_num = self.cache_config.max_prefill_token_num + max_batches = self.cache_config.max_batches + if self.dllm_block_length * max_batches > max_prefill_token_num: + logger.warning(f'dllm_block_length({self.dllm_block_length}) * max_batch_size ({max_batches}) ' + f'> max_prefill_token_num ({max_prefill_token_num}). 
' + 'This may lead to OOM. Consider to reduce max_batch_size or dllm_block_length.') + @lru_cache(maxsize=2) def get_prealloc_size(self, is_decoding: bool) -> int: """Get prealloc_size.""" From e8771be6dbd8c93a9882dafcf6a686e28d71fd58 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 18 Sep 2025 16:13:53 +0800 Subject: [PATCH 26/29] rename field;comment unmasking strategy --- lmdeploy/pytorch/config.py | 8 ++++++-- lmdeploy/pytorch/models/sdar.py | 2 +- lmdeploy/pytorch/models/sdar_moe.py | 2 +- lmdeploy/pytorch/strategies/dllm/__init__.py | 6 +++--- lmdeploy/pytorch/strategies/dllm/model_agent.py | 2 +- lmdeploy/pytorch/strategies/dllm/unmasking.py | 10 +++++----- 6 files changed, 17 insertions(+), 13 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 263978aa56..5c9639a71b 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -291,8 +291,12 @@ def from_hf_config(cls, class UnmaskingStrategy(enum.Enum): """Unmasking Strategy.""" + + # unmasking from left to right SEQUENTIAL = enum.auto() + # unmasking with confidence threshold LOW_CONFIDENCE_DYNAMIC = enum.auto() + # unmasking with topk in a block LOW_CONFIDENCE_STATIC = enum.auto() @classmethod @@ -311,7 +315,7 @@ def from_str(cls, strategy: str): @dataclass class DLLMConfig: - dllm_block_length: int = 1 + block_length: int = 1 unmasking_strategy: UnmaskingStrategy = UnmaskingStrategy.LOW_CONFIDENCE_DYNAMIC denoising_steps: int = None confidence_threshold: float = 0.85 @@ -332,7 +336,7 @@ class MiscConfig: def from_engine_config(cls, engine_config: PytorchEngineConfig): """From engine config.""" dllm_unmasking_strategy = UnmaskingStrategy.from_str(engine_config.dllm_unmasking_strategy) - dllm_config = DLLMConfig(dllm_block_length=engine_config.dllm_block_length, + dllm_config = DLLMConfig(block_length=engine_config.dllm_block_length, unmasking_strategy=dllm_unmasking_strategy, denoising_steps=engine_config.dllm_denoising_steps, confidence_threshold=engine_config.dllm_confidence_threshold) diff --git a/lmdeploy/pytorch/models/sdar.py b/lmdeploy/pytorch/models/sdar.py index 9f69c525e0..6a624e40e4 100644 --- a/lmdeploy/pytorch/models/sdar.py +++ b/lmdeploy/pytorch/models/sdar.py @@ -301,7 +301,7 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr - config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.dllm_block_length + config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length # build model self.model = SDARModel(config, dtype=dtype, device=device) # build lm_head diff --git a/lmdeploy/pytorch/models/sdar_moe.py b/lmdeploy/pytorch/models/sdar_moe.py index d7c569402a..522d2aed95 100644 --- a/lmdeploy/pytorch/models/sdar_moe.py +++ b/lmdeploy/pytorch/models/sdar_moe.py @@ -370,7 +370,7 @@ def __init__(self, super().__init__() self.config = config self.ctx_mgr = ctx_mgr - config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.dllm_block_length + config.dllm_block_length = ctx_mgr.build_ctx.dllm_config.block_length # build model self.model = SDARMoeModel(config, dtype=dtype, device=device) # build lm_head diff --git a/lmdeploy/pytorch/strategies/dllm/__init__.py b/lmdeploy/pytorch/strategies/dllm/__init__.py index 25736033c1..dc0395a017 100644 --- a/lmdeploy/pytorch/strategies/dllm/__init__.py +++ b/lmdeploy/pytorch/strategies/dllm/__init__.py @@ -30,18 +30,18 @@ def __init__(self, model_config: ModelConfig, dllm_config: DLLMConfig): def _update_dllm_block_length(self): """Update dllm_block_length.""" - if 
self.dllm_config.dllm_block_length is None: + if self.dllm_config.block_length is None: dllm_block_length = self.model_config.dllm_block_length if dllm_block_length is None: dllm_block_length = 4 logger.warning('Model does not provide dllm_block_length. ' f'Set dllm_block_length={dllm_block_length} as default.') else: - dllm_block_length = self.dllm_config.dllm_block_length + dllm_block_length = self.dllm_config.block_length assert dllm_block_length is not None, 'dllm_block_length should be set in model_config or dllm_config' - self.dllm_config.dllm_block_length = dllm_block_length + self.dllm_config.block_length = dllm_block_length self.model_config.dllm_block_length = dllm_block_length if self.dllm_config.denoising_steps is None: diff --git a/lmdeploy/pytorch/strategies/dllm/model_agent.py b/lmdeploy/pytorch/strategies/dllm/model_agent.py index 1f98bb5198..a5104d4981 100644 --- a/lmdeploy/pytorch/strategies/dllm/model_agent.py +++ b/lmdeploy/pytorch/strategies/dllm/model_agent.py @@ -108,7 +108,7 @@ def step(self, class DLLMModelAgentStrategy(ModelAgentStrategy): def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int): - block_size = dllm_config.dllm_block_length + block_size = dllm_config.block_length self.block_size = block_size self.dllm_mask_token = dllm_mask_token diff --git a/lmdeploy/pytorch/strategies/dllm/unmasking.py b/lmdeploy/pytorch/strategies/dllm/unmasking.py index dba0a10a69..7c24ac8d3f 100644 --- a/lmdeploy/pytorch/strategies/dllm/unmasking.py +++ b/lmdeploy/pytorch/strategies/dllm/unmasking.py @@ -23,7 +23,7 @@ def _get_scores(self, logits: torch.Tensor, token_ids: torch.Tensor): def _get_denoise_num(self): """Get denoise num.""" - block_size = self.dllm_config.dllm_block_length + block_size = self.dllm_config.block_length denoising_steps = self.dllm_config.denoising_steps if denoising_steps is None: denoising_steps = block_size @@ -33,7 +33,7 @@ def _get_denoise_num(self): def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """static.""" - block_size = self.dllm_config.dllm_block_length + block_size = self.dllm_config.block_length topk = self._get_denoise_num() scores = self._get_scores(logits, token_ids) is_masked = dllm_mask == DLLM_MASKED @@ -50,7 +50,7 @@ def low_confidence_static(self, logits: torch.Tensor, token_ids: torch.Tensor, d def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, dllm_mask: torch.Tensor): """dynamic.""" - block_size = self.dllm_config.dllm_block_length + block_size = self.dllm_config.block_length threshold = self.dllm_config.confidence_threshold scores = self._get_scores(logits, token_ids) is_masked = dllm_mask == DLLM_MASKED @@ -68,7 +68,7 @@ def low_confidence_dynamic(self, logits: torch.Tensor, token_ids: torch.Tensor, def sequential(self, dllm_mask: torch.Tensor): """sequential.""" - block_size = self.dllm_config.dllm_block_length + block_size = self.dllm_config.block_length denoise_num = self._get_denoise_num() dllm_mask = dllm_mask.view(-1, block_size) is_masked = dllm_mask == DLLM_MASKED @@ -93,7 +93,7 @@ def __call__(self, logits: torch.Tensor, input_ids: torch.Tensor, token_ids: tor return dllm_mask # reshape to [num_blocks, block_size] - block_size = self.dllm_config.dllm_block_length + block_size = self.dllm_config.block_length dllm_mask = dllm_mask.unflatten(0, (-1, block_size)) is_same = (dllm_mask == dllm_mask[:, :1]).all(dim=1) From 59c7c62ccffab514c059d7401bf554967847a4ff Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 18 Sep 2025 
17:31:58 +0800 Subject: [PATCH 27/29] suppression warning --- lmdeploy/pytorch/backends/cuda/attention.py | 28 ++++++++++++++++++--- lmdeploy/pytorch/check_env/transformers.py | 2 +- lmdeploy/pytorch/config.py | 5 +++- lmdeploy/pytorch/disagg/backend/__init__.py | 2 +- lmdeploy/tokenizer.py | 7 ++++-- 5 files changed, 36 insertions(+), 8 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index b5ecc3cbef..b241c384b2 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import functools from dataclasses import dataclass from typing import Literal @@ -20,8 +21,8 @@ assert torch.ops.flash_attn_3 is not None use_fa3 = True except Exception: - logger.warning('For higher performance, please install FlashAttention-3 ' - 'https://github.com/Dao-AILab/flash-attention') + logger.debug('For higher performance, please install FlashAttention-3 ' + 'https://github.com/Dao-AILab/flash-attention') @dataclass @@ -221,6 +222,15 @@ def forward( return attn_output +@functools.lru_cache +def use_fa3_warning(): + if use_fa3: + return True + logger.warning('For higher performance, please install FlashAttention-3 ' + 'https://github.com/Dao-AILab/flash-attention') + return False + + class FlashMLAImpl(TritonAttentionImpl): def __init__( @@ -255,6 +265,7 @@ def __init__( from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd self.flash_mla_fwd = flash_mla_fwd assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1' + use_fa3_warning() def forward( self, @@ -515,6 +526,14 @@ def forward( return attn_output +@functools.lru_cache +def _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int): + enable = not alibi and not learnable_sink and block_sparse_size == 1 + if enable and not use_fa3_warning(): + enable = False + return enable + + class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]): """Triton attention builder.""" @@ -535,8 +554,9 @@ def build( **kwargs, ) -> TritonAttentionImpl: """build.""" - enable_fa3 = use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1 + enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size) if use_flash_mla is True: + logger.debug('Build FlashMLAImpl Attention') return FlashMLAImpl(num_heads, head_size, scale=scale, @@ -548,6 +568,7 @@ def build( causal=causal, **kwargs) elif enable_fa3: + logger.debug('Build FA3Impl Attention') return FA3Impl(num_heads, head_size, scale=scale, @@ -559,6 +580,7 @@ def build( causal=causal, **kwargs) else: + logger.debug('Build TritonAttentionImpl Attention') return TritonAttentionImpl(num_heads, head_size, scale=scale, diff --git a/lmdeploy/pytorch/check_env/transformers.py b/lmdeploy/pytorch/check_env/transformers.py index defd43303b..20102b1deb 100644 --- a/lmdeploy/pytorch/check_env/transformers.py +++ b/lmdeploy/pytorch/check_env/transformers.py @@ -4,7 +4,7 @@ from .base import BaseChecker MIN_TRANSFORMERS_VERSION = '4.33.0' -MAX_TRANSFORMERS_VERSION = '4.53.3' +MAX_TRANSFORMERS_VERSION = '4.56.1' class TransformersChecker(BaseChecker): diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 5c9639a71b..ac3459e045 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -28,7 +28,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): config.dtype = torch.float16 return config - torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + 
torch_dtype = getattr(config.hf_config, 'dtype', None) + if torch_dtype is None: + torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + # deal with case when torch_dtype is not string but torch.dtype if isinstance(torch_dtype, torch.dtype): torch_dtype = str(torch_dtype).split('.')[1] diff --git a/lmdeploy/pytorch/disagg/backend/__init__.py b/lmdeploy/pytorch/disagg/backend/__init__.py index 3ab02d3bd6..2309e16c02 100644 --- a/lmdeploy/pytorch/disagg/backend/__init__.py +++ b/lmdeploy/pytorch/disagg/backend/__init__.py @@ -7,7 +7,7 @@ logger.debug('Registering DLSlime Backend') from .dlslime import DLSlimeBackend except ImportError: - logger.warning('Disable DLSlime Backend') + logger.debug('Disable DLSlime Backend') try: logger.debug('Registering Mooncake Backend') diff --git a/lmdeploy/tokenizer.py b/lmdeploy/tokenizer.py index acf2ede82d..a05638c746 100644 --- a/lmdeploy/tokenizer.py +++ b/lmdeploy/tokenizer.py @@ -423,8 +423,11 @@ class Tokenizer: """ def __init__(self, model_path: str): - from transformers import PretrainedConfig - model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) + from transformers import AutoConfig, PretrainedConfig + try: + model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + except BaseException: + model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True) is_gpt_oss = getattr(model_cfg, 'model_type', '') == 'gpt_oss' from transformers.models.auto.tokenization_auto import get_tokenizer_config tokenizer_config = get_tokenizer_config(model_path, trust_remote_code=True) From 1e47c310405b916c3e82b37d676ef0bdedd29f17 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 18 Sep 2025 20:45:22 +0800 Subject: [PATCH 28/29] colored vis --- lmdeploy/pytorch/tools/utils.py | 66 ++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/lmdeploy/pytorch/tools/utils.py b/lmdeploy/pytorch/tools/utils.py index 32b707b140..98d42772c0 100644 --- a/lmdeploy/pytorch/tools/utils.py +++ b/lmdeploy/pytorch/tools/utils.py @@ -132,6 +132,13 @@ def visualize_pipe_out(outputs, enable_meta: bool = True): from lmdeploy.messages import Response + try: + from termcolor import colored + except ImportError: + + def colored(text, color=None, on_color=None, attrs=None): + return text + if isinstance(outputs, Response): outputs = [outputs] try: @@ -139,24 +146,49 @@ def visualize_pipe_out(outputs, enable_meta: bool = True): except Exception: term_size = 100 - def _lined_print(msg: str, line_format: str = '-', full_line: bool = False): - print(msg) - if full_line: - columns = term_size - else: - columns = max(len(m) for m in msg.split('\n')) - print(line_format * columns) + border_color = 'cyan' + meta_color = 'light_grey' + number_color = 'green' + + def _print_title(title: str, color: str = border_color): + title_text = f' {title} ' + print(colored(f'【{title_text}】', color, attrs=['bold'])) + + def _print_section(title: str, content: str, color: str = border_color): + """Simple title and content printing.""" + _print_title(title, color) + print(content) + + def _print_meta(out: Response): + """Enhanced meta information display.""" + # Create a clean table-like format + finish_color = 'yellow' if out.finish_reason == 'stop' else 'red' + meta_content = [ + f"{colored('• Input Tokens:', meta_color)} {colored(out.input_token_len, number_color)}", + f"{colored('• Generated Tokens:', meta_color)} {colored(out.generate_token_len, number_color)}", + f"{colored('• Finish Reason:', 
meta_color)} {colored(out.finish_reason, finish_color)}" + ] + + lines = '\n'.join(meta_content) + lines += '\n' + _print_section('METADATA', lines, border_color) + + # Main loop + print(colored('━' * term_size, border_color)) outputs: List[Response] = outputs - term_line = '—' * term_size - print(term_line) for idx, out in enumerate(outputs): - _lined_print(f'output[{idx}]', '=') + header = f'OUTPUT [{idx + 1}/{len(outputs)}]' + header_formatted = colored(f'✦ {header}', 'light_magenta', attrs=['bold']) + print(header_formatted) + print() + if enable_meta: - _lined_print('meta', '-') - _lined_print( - f'input_token_len={out.input_token_len}\n' - f'generate_token_len={out.generate_token_len}\n' - f'finish_reason="{out.finish_reason}"', '—') - _lined_print('text', '-') - _lined_print(f'{out.text}', '—', full_line=True) + _print_meta(out) + + _print_section('TEXT', out.text, border_color) + + if idx < len(outputs) - 1: # Add separator when it's not the last output + print(colored('─' * (term_size), border_color, attrs=['dark'])) + else: + print(colored('━' * term_size, border_color)) From ee71d91df61f5d3d0df03eda1da2b69efe3b059e Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 18 Sep 2025 20:53:22 +0800 Subject: [PATCH 29/29] fix dummy --- lmdeploy/pytorch/strategies/base/model_inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/strategies/base/model_inputs.py b/lmdeploy/pytorch/strategies/base/model_inputs.py index a8b97c775e..f27134077d 100644 --- a/lmdeploy/pytorch/strategies/base/model_inputs.py +++ b/lmdeploy/pytorch/strategies/base/model_inputs.py @@ -21,7 +21,7 @@ def make_dummy_inputs(batch_size: int, 1, num_tokens, ), dtype=torch.long, device=device) - seq_length = torch.ones((batch_size, ), dtype=torch.long, device=device) + seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long, device=device) history_lengths = torch.zeros((batch_size, ), dtype=torch.long, device=device) block_offsets = torch.full((batch_size, 1), dummy_block_id, dtype=torch.long, device=device) num_ignored_history = torch.zeros((batch_size, ), dtype=torch.long, device=device)
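
Reader's note on PATCH 20 and PATCH 25: the decode-loop count and the batch-size clamp are plain arithmetic over the cache and scheduler settings. The sketch below restates that arithmetic as standalone Python; the function names and the example numbers are illustrative only and are not part of the series.

# Illustrative restatement of the arithmetic in PATCH 20 (get_num_loops) and
# PATCH 25 (max_batch_size clamp); names and numbers are examples only.

def get_num_loops(prefill_interval: int, block_size: int, dllm_block_length: int) -> int:
    # cap at twice the number of dllm blocks that fit in one cache block,
    # matching DLLMEngineStrategy.get_num_loops after the fix
    max_num_loops = block_size // dllm_block_length * 2
    return min(prefill_interval, max_num_loops)


def clamp_max_batch_size(max_batch_size: int, dllm_block_length: int, max_prefill_token_num: int) -> int:
    # each decoding sequence contributes dllm_block_length tokens per step,
    # so the whole batch has to fit into the prefill token budget
    if max_batch_size * dllm_block_length > max_prefill_token_num:
        max_batch_size = max_prefill_token_num // dllm_block_length
    return max_batch_size


print(get_num_loops(prefill_interval=16, block_size=64, dllm_block_length=4))   # 16, i.e. min(16, 32)
print(clamp_max_batch_size(max_batch_size=2048, dllm_block_length=4, max_prefill_token_num=4096))  # 1024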
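
PATCH 21 and PATCH 26 route unmasking through the UnmaskingStrategy enum. The behaviours named there (sequential, low_confidence_static, low_confidence_dynamic) can be summarised with the standalone sketch below; the flat per-block layout and the use of a per-token confidence score are assumptions made for illustration and do not reproduce strategies/dllm/unmasking.py.

# Minimal sketch of the three unmasking strategies over one dllm block.
# Shapes and the confidence definition are assumptions, not the engine's code.
import torch


def unmask_block(scores: torch.Tensor, is_masked: torch.Tensor, strategy: str,
                 denoise_num: int = 1, threshold: float = 0.85) -> torch.Tensor:
    """Return a bool tensor marking which masked positions of one block to unmask."""
    unmask = torch.zeros_like(is_masked)
    if strategy == 'sequential':
        # reveal the left-most `denoise_num` still-masked positions
        idx = is_masked.nonzero(as_tuple=True)[0][:denoise_num]
        unmask[idx] = True
    elif strategy == 'low_confidence_static':
        # reveal a fixed top-k of the most confident masked positions
        masked_scores = scores.masked_fill(~is_masked, float('-inf'))
        k = min(denoise_num, int(is_masked.sum()))
        unmask[masked_scores.topk(k)[1]] = True
    elif strategy == 'low_confidence_dynamic':
        # reveal every masked position whose confidence clears the threshold
        unmask = is_masked & (scores >= threshold)
    else:
        raise ValueError(f'unknown strategy: {strategy}')
    return unmask


scores = torch.tensor([0.9, 0.4, 0.95, 0.2])
is_masked = torch.tensor([True, True, True, True])
print(unmask_block(scores, is_masked, 'low_confidence_dynamic'))  # unmasks positions 0 and 2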