diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 075d9047737..b34998cc8ad 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -62,6 +62,7 @@ ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -506,6 +507,7 @@ def __init__( # Memory pool info self.req_pool_idx: Optional[int] = None + self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1) # Check finish self.tokenizer = None @@ -711,7 +713,12 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): self.last_host_node, self.host_hit_length, ) = tree_cache.match_prefix( - key=RadixKey(token_ids=token_ids, extra_key=self.extra_key) + key=RadixKey(token_ids=token_ids, extra_key=self.extra_key), + **( + {"req": self, "cow_mamba": True} + if isinstance(tree_cache, MambaRadixCache) + else {} + ), ) self.last_matched_prefix_len = len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) @@ -807,6 +814,7 @@ def reset_for_retract(self): self.extend_logprob_start_len = 0 self.is_chunked = 0 self.req_pool_idx = None + self.mamba_pool_idx = None self.already_computed = 0 def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): @@ -1003,6 +1011,13 @@ def is_empty(self): def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None): if isinstance(self.req_to_token_pool, HybridReqToTokenPool): + mamba_available_size = self.req_to_token_pool.mamba_pool.available_size() + if mamba_available_size < num_reqs: + if self.tree_cache is not None and isinstance( + self.tree_cache, MambaRadixCache + ): + mamba_num = max(0, num_reqs - mamba_available_size) + self.tree_cache.evict_mamba(mamba_num) req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs) else: req_pool_indices = self.req_to_token_pool.alloc(num_reqs) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 2fb355b031e..288984bb872 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -27,6 +27,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.server_args import ServerArgs @@ -357,6 +358,7 @@ def __init__( self.is_hybrid = isinstance( self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator ) + self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache) self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold @@ -380,6 +382,11 @@ def rem_total_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) + elif self.is_hybrid_gdn_cache: + available_and_evictable = ( + self.token_to_kv_pool_allocator.available_size() + + self.tree_cache.full_evictable_size() + ) else: available_and_evictable = ( 
self.token_to_kv_pool_allocator.available_size() @@ -397,6 +404,11 @@ def cur_rem_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) + elif self.is_hybrid_gdn_cache: + available_and_evictable = ( + self.token_to_kv_pool_allocator.available_size() + + self.tree_cache.full_evictable_size() + ) else: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d4c8d590258..7130f9b36e4 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -146,6 +146,7 @@ from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ( @@ -497,6 +498,10 @@ def __init__( # Hybrid memory pool self.is_hybrid = self.tp_worker.is_hybrid + self.is_hybrid_gdn = ( + self.tp_worker.worker.model_runner.hybrid_gdn_config is not None + ) + if self.is_hybrid: self.sliding_window_size = self.tp_worker.sliding_window_size self.full_tokens_per_layer, self.swa_tokens_per_layer = ( @@ -799,6 +804,16 @@ def init_memory_pool_and_cache(self): disable=server_args.disable_radix_cache, is_eagle=self.spec_algorithm.is_eagle(), ) + elif self.is_hybrid_gdn: + assert ( + self.server_args.disaggregation_mode == "null" + ), "Hybrid GDN mode does not support disaggregation yet" + self.tree_cache = MambaRadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + page_size=self.page_size, + disable=server_args.disable_radix_cache, + ) elif server_args.enable_lmcache: from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( LMCRadixCache, @@ -1670,6 +1685,25 @@ def check_memory(self): f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n" f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n" ) + elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache): + ( + full_num_used, + mamba_num_used, + _, + _, + full_available_size, + full_evictable_size, + mamba_available_size, + mamba_evictable_size, + ) = self._get_mamba_token_info() + memory_leak = ( + full_num_used != self.tree_cache.full_protected_size() + or mamba_num_used != self.tree_cache.mamba_protected_size() + ) + token_msg = ( + f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n" + f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n" + ) else: _, _, available_size, evictable_size = self._get_token_info() protected_size = self.tree_cache.protected_size() @@ -1720,6 +1754,17 @@ def check_memory(self): ) = self._get_swa_token_info() num_used = max(full_num_used, swa_num_used) token_usage = max(full_token_usage, swa_token_usage) + elif self.is_hybrid_gdn: + ( + num_used, + _, + token_usage, + _, + _, + _, + _, + _, + ) = self._get_mamba_token_info() else: num_used, token_usage, _, _ = self._get_token_info() num_running_reqs 
= len(self.running_batch.reqs) @@ -1747,7 +1792,9 @@ def check_memory(self): self._publish_kv_events() def check_tree_cache(self): - if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache): + if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or ( + self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache) + ): self.tree_cache.sanity_check() def _get_token_info(self): @@ -1757,6 +1804,35 @@ def _get_token_info(self): token_usage = num_used / self.max_total_num_tokens return num_used, token_usage, available_size, evictable_size + def _get_mamba_token_info(self): + is_radix_tree = isinstance(self.tree_cache, MambaRadixCache) + full_available_size = self.token_to_kv_pool_allocator.available_size() + full_evictable_size = ( + self.tree_cache.full_evictable_size() if is_radix_tree else 0 + ) + mamba_available_size = self.req_to_token_pool.mamba_pool.available_size() + mamba_evictable_size = ( + self.tree_cache.mamba_evictable_size() if is_radix_tree else 0 + ) + full_num_used = self.token_to_kv_pool_allocator.size - ( + full_available_size + full_evictable_size + ) + mamba_num_used = self.req_to_token_pool.mamba_pool.size - ( + mamba_available_size + mamba_evictable_size + ) + full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size + mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size + return ( + full_num_used, + mamba_num_used, + full_token_usage, + mamba_usage, + full_available_size, + full_evictable_size, + mamba_available_size, + mamba_evictable_size, + ) + def _get_swa_token_info(self): full_available_size = self.token_to_kv_pool_allocator.full_available_size() full_evictable_size = self.tree_cache.full_evictable_size() diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 4fa4bfee1dc..a0051c7302a 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -104,6 +104,23 @@ def log_prefill_stats( f"full token usage: {full_token_usage:.2f}, " f"swa token usage: {swa_token_usage:.2f}, " ) + elif self.is_hybrid_gdn: + ( + full_num_used, + _, + full_token_usage, + mamba_usage, + _, + _, + _, + _, + ) = self._get_mamba_token_info() + num_used = full_num_used + token_usage = full_token_usage + token_usage_msg = ( + f"full token usage: {full_token_usage:.2f}, " + f"mamba usage: {mamba_usage:.2f}, " + ) else: num_used, token_usage, _, _ = self._get_token_info() token_usage_msg = f"token usage: {token_usage:.2f}, " @@ -203,6 +220,25 @@ def log_decode_stats( f"#swa token: {swa_num_used}, " f"swa token usage: {swa_token_usage:.2f}, " ) + elif self.is_hybrid_gdn: + ( + full_num_used, + mamba_used, + full_token_usage, + mamba_usage, + _, + _, + _, + _, + ) = self._get_mamba_token_info() + num_used = full_num_used + token_usage = full_token_usage + token_usage_msg = ( + f"#full token: {full_num_used}, " + f"full token usage: {full_token_usage:.2f}, " + f"mamba num: {mamba_used}, " + f"mamba usage: {mamba_usage:.2f}, " + ) else: num_used, token_usage, _, _ = self._get_token_info() token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, " diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py new file mode 100644 index 00000000000..7467daa5d56 --- /dev/null +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -0,0 +1,995 @@ +from __future__ import annotations + +""" +Copyright 2023-2024 SGLang Team +Licensed under 
the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +The radix tree data structure for managing the hybrid (full and Mamba) KV cache. +""" + +import heapq +import time +from collections import defaultdict +from functools import partial +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch + +from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool +from sglang.srt.mem_cache.radix_cache import ( + RadixKey, + _key_match_page_size1, + _key_match_paged, + get_child_key, +) + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + +import logging + +logger = logging.getLogger(__name__) + + +class TreeNode: + + counter = 0 + + def __init__(self, id: Optional[int] = None): + self.children = defaultdict(TreeNode) + self.parent: TreeNode = None + self.key: RadixKey = None + self.value: Optional[torch.Tensor] = None + self.mamba_value: Optional[torch.Tensor] = None + # invariant: for any node, if mamba_lock_ref is locked, full_lock_ref must be locked; + # if full_lock_ref is locked, mamba_lock_ref doesn't need to be locked. So, + # full_lock_ref is always >= mamba_lock_ref. + # for full_lock, once it is locked, its parent must be locked as well + # for mamba_lock, it only need lock node itself + self.full_lock_ref = 0 + self.mamba_lock_ref = 0 + # last access time is only used for sanity check. LRU is maintained by the lru list. + self.last_access_time = time.monotonic() + + self.hit_count = 0 + # store the host indices of KV cache + self.host_value = None + + # for lru list, invariant: + # 1. prev has greater last_access_time + # 2. 
next has smaller last_access_time + self.prev = None + self.next = None + self.mamba_prev = None + self.mamba_next = None + + self.id = TreeNode.counter if id is None else id + TreeNode.counter += 1 + + @property + def evicted(self): + return self.value is None + + @property + def backuped(self): + return self.host_value is not None + + def __lt__(self, other: "TreeNode"): + return self.last_access_time < other.last_access_time + + +class LRUList: + def __init__(self, mamba: bool = False): + self.mamba = mamba + if self.mamba: + self.prv = "mamba_prev" + self.nxt = "mamba_next" + self.lock_ref = "mamba_lock_ref" + else: + self.prv = "prev" + self.nxt = "next" + self.lock_ref = "full_lock_ref" + # Initialize dummy head and tail nodes + self.head = TreeNode() # Most recently used side + self.tail = TreeNode() # Least recently used side + setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail + setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head + self.cache = {} + + def _add_node(self, node): + """Helper to add node right after head (most recently used)""" + self._add_node_after(self.head, node) + + def _add_node_after(self, old_node, new_node): + """Helper to add node right after old_node""" + setattr(new_node, self.prv, old_node) # new_node.prev = old_node + setattr( + new_node, self.nxt, getattr(old_node, self.nxt) + ) # new_node.next = old_node.next + setattr( + getattr(old_node, self.nxt), self.prv, new_node + ) # old_node.next.prev = new_node + setattr(old_node, self.nxt, new_node) # old_node.next = new_node + + def _remove_node(self, node): + """Helper to remove node from linked list""" + setattr( + getattr(node, self.prv), self.nxt, getattr(node, self.nxt) + ) # node.prev.next = node.next + setattr( + getattr(node, self.nxt), self.prv, getattr(node, self.prv) + ) # node.next.prev = node.prev + + def _get_lru(self) -> Optional[TreeNode]: + """ + Get the least recently used node + """ + if len(self.cache) == 0: + return None + return getattr(self.tail, self.prv) + + def reset_node_mru(self, node): + """ + Move a (existing) node to most recently used position + """ + assert node.id in self.cache, f"Resetting node {node.id=} not in lru list" + assert ( + not self.mamba or node.mamba_value is not None + ), f"Resetting mamba tombstone node in mamba lru list: {node.id=}" + self._remove_node(node) + self._add_node(node) + + def reset_node_and_parents_mru(self, node, root_node): + """ + Move an (existing) node and its parents to most recently used position. Child node is + more recently used than parent node. 
+ """ + prev_node = self.head + while node != root_node: + if not self.mamba or node.mamba_value is not None: + assert ( + node.id in self.cache + ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru" + self._remove_node(node) + self._add_node_after(prev_node, node) + prev_node = node + node = node.parent + + def insert_mru(self, node): + """ + Insert a (new) node as most recently used + """ + assert ( + not self.mamba or node.mamba_value is not None + ), f"Inserting mamba tombstone node in mamba lru list: {node.id=}" + assert ( + node.id not in self.cache + ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}" + self.cache[node.id] = node + self._add_node(node) + + def remove_node(self, node: TreeNode): + """ + Remove node from lru list + """ + assert node.id in self.cache, f"Removing node {node.id=} not in lru list" + assert ( + not self.mamba or node.mamba_value is not None + ), f"Removing mamba tombstone node from mamba lru list: {node.id=}" + del self.cache[node.id] + self._remove_node(node) + + def get_lru_no_lock(self) -> Optional[TreeNode]: + """ + Get the least recently used node that is not locked + """ + return self.get_prev_no_lock(self.tail, check_id=False) + + def get_leaf_lru_no_lock(self) -> Optional[TreeNode]: + """ + Get the least recently used leaf node that is not locked + """ + return self.get_prev_leaf_no_lock(self.tail, check_id=False) + + def get_prev_no_lock( + self, node: TreeNode, check_id: bool = True + ) -> Optional[TreeNode]: + """ + Get the previous (i.e. more recently used) node that is not locked + """ + if check_id: + assert ( + node.id in self.cache + ), f"Getting prev of node {node.id=} not in lru list" + x = getattr(node, self.prv) # x = node.prev + while getattr(x, self.lock_ref) > 0: + x = getattr(x, self.prv) # x = x.prev + # if x is the head, it means there is no node in the lru list without lock + if x == self.head: + return None + return x + + def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True): + """ + Get the previous (i.e. more recently used) leaf node that is not locked + """ + if check_id: + assert ( + node.id in self.cache + ), f"Getting prev of node {node.id=} not in lru list" + x = getattr(node, self.prv) # x = node.prev + while getattr(x, self.lock_ref) > 0 or len(x.children) > 0: + x = getattr(x, self.prv) # x = x.prev + # if x is the head, it means there is no leaf node in the lru list without lock + if x == self.head: + return None + return x + + def in_list(self, node: Optional[TreeNode]): + """ + Check if the node is in the lru list + """ + if not node: + return False + return node.id in self.cache + + # Note: this is expensive, only use for debug + def sanity_check_evictable_size(self): + """ + Check the evictable size (i.e. the size of the nodes that are not locked) + """ + node = self.get_lru_no_lock() + evictable_size = 0 + while self.in_list(node): + evictable_size += ( + len(node.value) if not self.mamba else len(node.mamba_value) + ) + node = self.get_prev_no_lock(node) + return evictable_size + + # Note: this is expensive, only use for debug or idle check + def sanity_check(self, tree_cache: "MambaRadixCache"): + """ + Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and + checking if the lru list is valid. 
+ """ + try: + if self.mamba: + nodes = tree_cache._collect_nontombstone_nodes() + else: + nodes = tree_cache._collect_all_nodes() + total_nodes = len(nodes) + total_lru = len(self.cache) + # heapify based on last_access_time + heapq.heapify(nodes) + # the root node is not in the lru list + assert len(nodes) == ( + total_lru + (0 if self.mamba else 1) + ), f"len(nodes): {len(nodes)}, total_lru: {total_lru}" + + x_lru = self._get_lru() + while len(nodes): + x = heapq.heappop(nodes) + if x == tree_cache.root_node: + # root node is not in the lru list + continue + assert ( + x == x_lru + ), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}" + assert ( + x_lru.full_lock_ref == 0 + ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}" + assert ( + x_lru.mamba_lock_ref == 0 + ), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}" + x_lru = getattr(x, self.prv) + + if self.mamba: + evictable_size = tree_cache.mamba_evictable_size() + lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size() + else: + evictable_size = tree_cache.full_evictable_size() + lru_list_evictable_size = tree_cache.full_lru_list_evictable_size() + + assert ( + evictable_size == lru_list_evictable_size + ), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}" + except Exception as e: + msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}" + logger.error(msg) + raise Exception(msg) + + +class MambaRadixCache(BasePrefixCache): + def __init__( + self, + req_to_token_pool: HybridReqToTokenPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, + page_size: int, + disable: bool = False, + ): + assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator) + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + + assert page_size == 1, "Only support page_size=1 in mamba radix cache now." + self.page_size = page_size + self.disable = disable + + if self.token_to_kv_pool_allocator: + self.device = self.token_to_kv_pool_allocator.device + else: + self.device = torch.device("cpu") + + self.key_match_fn = _key_match_page_size1 + self.get_child_key_fn = get_child_key + self.reset() + + ##### Public API ##### + + def reset(self) -> None: + self.root_node = TreeNode() + self.root_node.key = [] + self.root_node.value = [] + self.root_node.full_lock_ref = 1 + self.root_node.mamba_lock_ref = 1 + self.full_evictable_size_ = 0 + self.mamba_evictable_size_ = 0 + self.full_protected_size_ = 0 + self.mamba_protected_size_ = 0 + # LRU lists are used to maintain the order of eviction of the nodes in the tree + self.full_lru_list = LRUList(mamba=False) + self.mamba_lru_list = LRUList(mamba=True) + + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + """Find the matching prefix from the radix tree. + Args: + key: A RadixKey contains token IDs to find a matching prefix. + Returns: + A tuple of a tensor of matching prefix token IDs and + the last node that contains the prefix values. Note that + this API can modify the internal state of the Radix tree. + The last node create a new child if the prefix is shorter + than the last node's value. 
+ """ + cow_mamba: bool = kwargs.get("cow_mamba", False) + req: Req = kwargs.get("req", None) + + if self.disable or len(key) == 0: + return MatchResult( + device_indices=torch.empty( + (0,), + dtype=torch.int64, + device=self.device, + ), + last_device_node=self.root_node, + last_host_node=self.root_node, + ) + + value, last_node = self._match_prefix_helper(key) + + # copy mamba state to req local space if cow is true + if cow_mamba and last_node.mamba_value is not None: + assert req.req_pool_idx is None # req_pool_idx is uninitialed + + # for reqs without mamba cache + if req.mamba_pool_idx is None: + dst_index = self.req_to_token_pool.mamba_pool.alloc(1) + # try to alloc again, protect last_node from eviction + if dst_index is None: + self.inc_lock_ref(last_node) + self.evict_mamba(1) + dst_index = self.req_to_token_pool.mamba_pool.alloc(1) + self.dec_lock_ref(last_node) + assert dst_index is not None, "Can not alloc mamba cache" + src_index = last_node.mamba_value + self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) + req.mamba_pool_idx = dst_index[0] + else: + src_index = last_node.mamba_value + dst_index = req.mamba_pool_idx.unsqueeze(0) + self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) + + if value: + value = torch.cat(value) + else: + value = torch.empty((0,), dtype=torch.int64, device=self.device) + + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_node, + ) + + def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]: + if self.disable: + return 0 + + if value is None: + value = torch.tensor([x for x in key.token_ids], dtype=torch.int64) + return self._insert_helper(self.root_node, key, value, mamba_value) + + def cache_finished_req(self, req: Req) -> None: + """Cache request when it finishes.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, + : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0), + ] + self.token_to_kv_pool_allocator.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + return + + token_ids = (req.origin_input_ids + req.output_ids)[:-1] + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + + # Radix Cache takes one ref in memory pool + # insert the token_ids and kv_indices into the radix tree + # Note: the insert function already frees the overlapped kv_indices + mamba_value = ( + self.req_to_token_pool.get_mamba_indices(req.req_pool_idx) + .unsqueeze(-1) + .clone() + ) + + new_prefix_len, mamba_exist = self.insert( + RadixKey(token_ids[:page_aligned_len], req.extra_key), + page_aligned_kv_indices, + mamba_value, + ) + + self.token_to_kv_pool_allocator.free( + kv_indices[len(req.prefix_indices) : new_prefix_len] + ) + + self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist) + self.dec_lock_ref(req.last_node) + + def cache_unfinished_req(self, req: Req, chunked=False) -> None: + """Cache request when it is unfinished.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.fill_ids) + ] + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + req.prefix_indices = kv_indices + return + + token_ids = req.fill_ids + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + page_aligned_len = len(kv_indices) + 
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + page_aligned_token_ids = token_ids[:page_aligned_len] + + mamba_value = self.req_to_token_pool.get_mamba_indices( + req.req_pool_idx + ).unsqueeze(-1) + # radix tree mamba value is forked from req space + mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value) + + # if alloc mamba cache failed, do evict and alloc again + if mamba_value_forked is None: + self.evict_mamba(1) + mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from( + mamba_value + ) + assert mamba_value_forked is not None, "Can not alloc mamba cache" + new_prefix_len, mamba_exist = self.insert( + RadixKey(page_aligned_token_ids, req.extra_key), + page_aligned_kv_indices, + mamba_value_forked, + ) + self.token_to_kv_pool_allocator.free( + kv_indices[len(req.prefix_indices) : new_prefix_len] + ) + # there is a mamba cache in radix cache, release it + if mamba_exist: + self.req_to_token_pool.mamba_pool.free(mamba_value_forked) + + # The prefix indices could be updated, reuse it + new_indices, new_last_node, _, _ = self.match_prefix( + RadixKey(page_aligned_token_ids, req.extra_key) + ) + + if not mamba_exist: + assert torch.equal(new_last_node.mamba_value, mamba_value_forked) + + assert len(req.prefix_indices) <= len( + new_indices + ), f"{req.prefix_indices=}, {new_indices=}" + assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}" + + self.req_to_token_pool.write( + (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), + new_indices[len(req.prefix_indices) :], + ) + + self.dec_lock_ref(req.last_node) + self.inc_lock_ref(new_last_node) + + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + req.prefix_indices = new_indices + req.last_node = new_last_node + + def pretty_print(self) -> None: + self._print_helper(self.root_node, 0) + total_size, total_mamba_size = self._total_size_helper() + print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}") + + def total_size(self) -> Tuple[int, int]: + return self._total_size_helper() + + def _evict_leaf_node( + self, x: TreeNode, is_evict_mamba: bool + ) -> Tuple[int, int, TreeNode, TreeNode]: + assert ( + x.full_lock_ref == 0 and x.mamba_lock_ref == 0 + ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}" + + assert x.mamba_value is not None, f"leaf node mamba value is not None, {x.id=}" + # 1. a leaf node, free full tokens and mamba + self.token_to_kv_pool_allocator.free(x.value) + full_num_evicted = len(x.value) + self.req_to_token_pool.mamba_pool.free(x.mamba_value) + mamba_num_evicted = len(x.mamba_value) + + # 2. get the next node, update the lru lists + if is_evict_mamba: + x_next = self.mamba_lru_list.get_prev_no_lock(x) + else: + x_next = self.full_lru_list.get_prev_leaf_no_lock(x) + self.full_lru_list.remove_node(x) + self.mamba_lru_list.remove_node(x) + + # 3. delete the leaf node + self._delete_leaf(x) + + # 4. 
Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone + x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x) + full_num_evicted += leaf_full_num_evicted + return full_num_evicted, mamba_num_evicted, x, x_next + + def evict_mamba(self, mamba_num: int) -> None: + if self.disable or mamba_num <= 0: + return + # get the least recently used node that is not locked, doesn't have to be a leaf + x = self.mamba_lru_list.get_lru_no_lock() + mamba_num_evicted = 0 + # evict lru leaf nodes until mamba_num_tokens is reached + while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)): + assert x.mamba_value is not None, f"node has no mamba value, {x.id=}" + assert ( + len(x.mamba_value) == 1 + ), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}" + assert x != self.root_node, f"root node is not evictable, {x.id=}" + assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}" + + if len(x.children) > 0: + # 1. an internal node, free mamba tokens. + self.req_to_token_pool.mamba_pool.free(x.mamba_value) + mamba_num_evicted += len(x.mamba_value) + + # 2. get the next node, update the lru lists + x_next = self.mamba_lru_list.get_prev_no_lock(x) + self.mamba_lru_list.remove_node(x) + + # 3. tombstone the node + self._tombstone_internal_node(x) + else: + _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True) + mamba_num_evicted += mamba_evicted_delta + + x = x_next + + def evict(self, full_num_tokens: int) -> None: + if self.disable or full_num_tokens <= 0: + return + + full_num_evicted = 0 + # get the least recently used leaf node that is not locked + x = self.full_lru_list.get_leaf_lru_no_lock() + + while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x): + assert ( + x != self.root_node + ), f"root node should not exist in full lru list, {x.id=}" + full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False) + full_num_evicted += full_num_evicted_delta + + # if parent has no more children, it is a leaf. It is possible that this node is lru, so + # we need to get the first leaf node in the lru list + if len(x.parent.children) == 0: + x_next = self.full_lru_list.get_leaf_lru_no_lock() + + x = x_next + + def inc_lock_ref(self, node: TreeNode) -> Optional[int]: + """ + Increment the lock reference count for the node. + It locks the full_lock_ref for nodes between the [last node, root), exclusive. + It locks the mamba_lock_ref for current node if its mamba_value exists. + """ + if self.disable: + return None + + # protect mamba value in current node if it exists + if node.mamba_value is not None: + if node.mamba_lock_ref == 0: + self.mamba_evictable_size_ -= len(node.mamba_value) + self.mamba_protected_size_ += len(node.mamba_value) + node.mamba_lock_ref += 1 + + while node != self.root_node: + # lock full from node to root + assert ( + node.full_lock_ref >= 0 + ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}" + if node.full_lock_ref == 0: + self.full_evictable_size_ -= len(node.value) + self.full_protected_size_ += len(node.value) + node.full_lock_ref += 1 + node = node.parent + return None + + def dec_lock_ref(self, node: TreeNode): + """ + Decrement the lock reference count for the node. + It unlocks the full_lock_ref for nodes between the [last node, root), exclusive. + It unlocks the mamba_lock_ref for current node if its mamba_value exists. 
+ """ + if self.disable: + return + + if node.mamba_value is not None: + assert ( + node.mamba_lock_ref > 0 + ), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}" + if node.mamba_lock_ref == 1: + self.mamba_evictable_size_ += len(node.mamba_value) + self.mamba_protected_size_ -= len(node.mamba_value) + node.mamba_lock_ref -= 1 + + while node != self.root_node: + assert ( + node.full_lock_ref > 0 + ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}" + if node.full_lock_ref == 1: + self.full_evictable_size_ += len(node.value) + self.full_protected_size_ -= len(node.value) + node.full_lock_ref -= 1 + node = node.parent + + def sanity_check(self): + self.full_lru_list.sanity_check(self) + self.mamba_lru_list.sanity_check(self) + + def evictable_size(self) -> Tuple[int, int]: + # Note: use full_evictable_size() and mamba_evictable_size() instead. + raise NotImplementedError + + def full_evictable_size(self) -> int: + return self.full_evictable_size_ + + def mamba_evictable_size(self) -> int: + return self.mamba_evictable_size_ + + # Note: this is expensive, only use for debug + def full_lru_list_evictable_size(self) -> int: + return self.full_lru_list.sanity_check_evictable_size() + + # Note: this is expensive, only use for debug + def mamba_lru_list_evictable_size(self) -> int: + return self.mamba_lru_list.sanity_check_evictable_size() + + def protected_size(self) -> Tuple[int, int]: + # Note: use full_protected_size() and mamba_protected_size() instead. + raise NotImplementedError + + def full_protected_size(self) -> int: + # protected size refers to the size of the full cache that is locked + return self.full_protected_size_ + + def mamba_protected_size(self) -> int: + # protected size refers to the size of the mamba cache that is locked + return self.mamba_protected_size_ + + def all_values_flatten(self) -> torch.Tensor: + values = [] + + def _dfs_helper(node: TreeNode): + for _, child in node.children.items(): + values.append(child.value) + _dfs_helper(child) + + _dfs_helper(self.root_node) + return torch.cat(values) + + ##### Internal Helper Functions ##### + + def _match_prefix_helper( + self, key: RadixKey + ) -> Tuple[List[torch.Tensor], TreeNode]: + """ + Mamba prefix matching helper. It factors in the sliding window size such that + the matched node is guaranteed to either 1. connected to root without mamba tombstone, + or 2. the number of matching tokens from the matched node to the last mamba tombstone + node is greater than or equal to the sliding window size. 
+ """ + node = self.root_node + child_key = self.get_child_key_fn(key) + + value = [] + best_value_len = 0 + best_last_node = node + while len(key) > 0 and child_key in node.children.keys(): + child = node.children[child_key] + # update best_value_len and best_last_node if needed + if node.mamba_value is not None: + best_value_len = len(value) + best_last_node = node + + prefix_len = self.key_match_fn(child.key, key) + if prefix_len < len(child.key): + new_node = self._split_node(child.key, child, prefix_len) + value.append(new_node.value) + node = new_node + break + else: + value.append(child.value) + node = child + key = key[prefix_len:] + + if len(key): + child_key = self.get_child_key_fn(key) + # handle best_value_len and best_last_node, for the case that last node is fully matched + if node.mamba_value is not None: + best_value_len = len(value) + best_last_node = node + + # update time for matched nodes, and make nodes closer to root to be least recently used + # this allows mamba to evict nodes closer to root first + self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) + self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) + + # This last_access_time is for sanity check, can be deleted after validation in production + cur_time = time.monotonic() + while node: + node.last_access_time = cur_time + cur_time -= 0.0001 + node = node.parent + + return value[:best_value_len], best_last_node + + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: + # new_node -> child + new_node = TreeNode() + new_node.children = {self.get_child_key_fn(key[split_len:]): child} + new_node.parent = child.parent + new_node.mamba_value = None # mamba cache can not be split + new_node.full_lock_ref = child.full_lock_ref + new_node.mamba_lock_ref = 0 + new_node.key = child.key[:split_len] + new_node.value = child.value[:split_len] + + # child time should be later than parent's time for mamba tombstone + child.last_access_time = time.monotonic() + + self.full_lru_list.remove_node(child) + if child.mamba_value is not None: + self.mamba_lru_list.remove_node(child) + child.parent = new_node + child.key = child.key[split_len:] + child.value = child.value[split_len:] + new_node.parent.children[self.get_child_key_fn(key)] = new_node + + # insert the new node and child into the lru lists, insert + # parent first so that parent is after child in the lru list + self.full_lru_list.insert_mru(new_node) + self.full_lru_list.insert_mru(child) + if child.mamba_value is not None: + self.mamba_lru_list.insert_mru(child) + return new_node + + def _insert_helper( + self, + node: TreeNode, + key: RadixKey, + value, + mamba_value, + ) -> Tuple[int, bool]: + # Update the last access time from root to leaf, so that + # mamba will tombstone the node closer to root first + assert mamba_value is not None, "Mamba value should not be None here." 
+ node.last_access_time = time.monotonic() + if node != self.root_node: + self.full_lru_list.reset_node_mru(node) + if node.mamba_value is not None: + self.mamba_lru_list.reset_node_mru(node) + if len(key) == 0: + return 0, True + + child_key = self.get_child_key_fn(key) + + total_prefix_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + self.full_lru_list.reset_node_mru(node) + if node.mamba_value is not None: + self.mamba_lru_list.reset_node_mru(node) + prefix_len = self.key_match_fn(node.key, key) + total_prefix_length += prefix_len + key = key[prefix_len:] + value = value[prefix_len:] + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + if len(key): + child_key = self.get_child_key_fn(key) + + mamba_value_exist = False + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + new_node.value = value + new_node.mamba_value = mamba_value + self.full_lru_list.insert_mru(new_node) + self.full_evictable_size_ += len(value) + self.mamba_evictable_size_ += len(mamba_value) + self.mamba_lru_list.insert_mru(new_node) + node.children[child_key] = new_node + elif node.mamba_value is None: # add for mamba tombstone + node.mamba_value = mamba_value + self.mamba_evictable_size_ += len(mamba_value) + self.mamba_lru_list.insert_mru(node) + else: + mamba_value_exist = True + self.mamba_lru_list.reset_node_mru(node) + + return total_prefix_length, mamba_value_exist + + def _iteratively_delete_tombstone_leaf( + self, node: TreeNode + ) -> Tuple[TreeNode, int]: + full_num_evicted = 0 + while node.parent.mamba_value is None and len(node.parent.children) == 0: + # root node is not evictable + if node.parent == self.root_node: + break + # if locked, means node is in use, skip + if node.parent.full_lock_ref > 0: + break + assert ( + node.parent.mamba_lock_ref == 0 + ), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}" + # delete tombstone node evicts full tokens + self.token_to_kv_pool_allocator.free(node.parent.value) + full_num_evicted += len(node.parent.value) + self.full_lru_list.remove_node(node.parent) + self._delete_tombstone_leaf(node.parent) + node = node.parent + + return node, full_num_evicted + + def _delete_leaf(self, node: TreeNode) -> None: + assert ( + node.mamba_value is not None + ), f"Invariant violated: leaf node is a tombstone, {node.id=}" + assert len(node.children) == 0, f"leaf node has children, {node.id=}" + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.full_evictable_size_ -= len(node.key) + self.mamba_evictable_size_ -= len(node.mamba_value) + + def _tombstone_internal_node(self, node: TreeNode) -> None: + assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}" + self.mamba_evictable_size_ -= len(node.mamba_value) + node.mamba_value = None + + def _delete_tombstone_leaf(self, node: TreeNode) -> None: + assert ( + node.mamba_value is None + ), f"Deleting a unexpected non-tombstone leaf node, {node.id=}" + assert len(node.children) == 0, f"leaf node has children, {node.id=}" + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.full_evictable_size_ -= len(node.key) + + def _collect_leaves(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + 
if len(cur_node.children) == 0: + ret_list.append(cur_node) + else: + stack.extend(cur_node.children.values()) + + return ret_list + + def _collect_nontombstone_nodes(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if cur_node.mamba_value is not None: + ret_list.append(cur_node) + stack.extend(cur_node.children.values()) + + return ret_list + + def _collect_all_nodes(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + while stack: + cur_node = stack.pop() + ret_list.append(cur_node) + stack.extend(cur_node.children.values()) + return ret_list + + def _print_helper(self, node: TreeNode, indent: int) -> None: + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + f"[{current_node.id}]", + len(current_node.key), + f"fr={current_node.full_lock_ref}", + f"mr={current_node.mamba_lock_ref}", + f"fll={self.full_lru_list.in_list(current_node)}", + f"mll={self.mamba_lru_list.in_list(current_node)}", + f"mv={current_node.mamba_value}", + ) + for key, child in current_node.children.items(): + stack.append((child, current_indent + 2)) + + assert key == self.get_child_key_fn( + child.key + ), f"{key=}, {self.get_child_key_fn(child.key)=}" + + def _total_size_helper(self) -> Tuple[int, int]: + total_size = 0 + total_mamba_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + if current_node.mamba_value is not None: + total_mamba_size += len(current_node.mamba_value) + for child in current_node.children.values(): + if child.evicted: + continue + stack.append(child) + return total_size, total_mamba_size diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index f948ed63619..15d48142ccb 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -190,6 +190,7 @@ def __init__( ) logger.info( f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {size}, " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " @@ -199,11 +200,13 @@ def __init__( self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) logger.info( f"Mamba Cache is allocated. 
" + f"max_mamba_cache_size: {size}, " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " ) self.size = size - self.free_slots = list(range(size)) + self.device = device + self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: @@ -216,7 +219,7 @@ def mamba2_layer_cache(self, layer_id: int): def available_size(self): return len(self.free_slots) - def alloc(self, need_size: int) -> Optional[List[int]]: + def alloc(self, need_size: int) -> Optional[torch.Tensor]: if need_size > len(self.free_slots): return None @@ -225,17 +228,30 @@ def alloc(self, need_size: int) -> Optional[List[int]]: return select_index - def free(self, free_index: Union[int, List[int]]): - if isinstance(free_index, (int,)): - self.free_slots.append(free_index) - else: - self.free_slots.extend(free_index) + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + self.free_slots = torch.cat((self.free_slots, free_index)) self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[ :, free_index ] = 0 def clear(self): - self.free_slots = list(range(self.size)) + self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) + + def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): + self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index] + self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[ + :, src_index + ] + return + + def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]: + dst_index = self.alloc(1) + if dst_index == None: + return None + self.copy_from(src_index, dst_index) + return dst_index class HybridReqToTokenPool(ReqToTokenPool): @@ -245,6 +261,7 @@ def __init__( self, *, size: int, + mamba_size: int, max_context_len: int, device: str, enable_memory_saver: bool, @@ -259,7 +276,7 @@ def __init__( ) self.mamba_pool = MambaPool( - size=size, + size=mamba_size, cache_params=cache_params, device=device, speculative_num_draft_tokens=speculative_num_draft_tokens, @@ -271,9 +288,6 @@ def __init__( size, dtype=torch.int32, device=self.device ) - self.rid_to_mamba_index_mapping: Dict[str, int] = {} - self.mamba_index_to_rid_mapping: Dict[int, str] = {} - # For chunk prefill req, we do not need to allocate mamba cache, # We could use allocated mamba cache instead. def alloc( @@ -285,14 +299,14 @@ def alloc( mamba_index = [] for req in reqs: - rid = req.rid - if rid in self.rid_to_mamba_index_mapping: - mid = self.rid_to_mamba_index_mapping[rid] - elif (mid := self.mamba_pool.alloc(1)) is not None: - mid = mid[0] - self.rid_to_mamba_index_mapping[rid] = mid - self.mamba_index_to_rid_mapping[mid] = rid - mamba_index.append(mid) + mid = None + if req.mamba_pool_idx is not None: # for radix cache + mid = req.mamba_pool_idx + else: + mid = self.mamba_pool.alloc(1)[0] + req.mamba_pool_idx = mid + if mid is not None: + mamba_index.append(mid) assert len(select_index) == len( mamba_index ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size." 
@@ -313,17 +327,12 @@ def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState # For chunk prefill, we can not free mamba cache, we need use it in the future def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True): + if isinstance(free_index, (int,)): + free_index = [free_index] super().free(free_index) if free_mamba_cache: mamba_index = self.req_index_to_mamba_index_mapping[free_index] - mamba_index_list = mamba_index.tolist() - if isinstance(mamba_index_list, int): - mamba_index_list = [mamba_index_list] - self.mamba_pool.free(mamba_index_list) - for mid in mamba_index_list: - rid = self.mamba_index_to_rid_mapping[mid] - self.mamba_index_to_rid_mapping.pop(mid) - self.rid_to_mamba_index_mapping.pop(rid) + self.mamba_pool.free(mamba_index) def clear(self): super().clear() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5b1b9d22a0a..5992b6231c6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -184,6 +184,9 @@ def add_mla_attention_backend(backend_name): # Detect stragger ranks in model loading UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 +# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077) +MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3 + logger = logging.getLogger(__name__) @@ -355,27 +358,6 @@ def initialize(self, min_per_gpu_memory: float): if architectures and not any("Llama4" in arch for arch in architectures): self.is_hybrid = self.model_config.is_hybrid = True - if config := self.mambaish_config: - class_name = config.__class__.__name__ - logger.warning(f"{class_name} model detected, disable radix cache") - self.server_args.disable_radix_cache = True - if self.server_args.max_mamba_cache_size is None: - if self.server_args.max_running_requests is not None: - self.server_args.max_mamba_cache_size = ( - self.server_args.max_running_requests - ) - else: - self.server_args.max_mamba_cache_size = 512 - if self.hybrid_gdn_config is not None: - self.server_args.max_mamba_cache_size = ( - self.server_args.max_mamba_cache_size - // ( - self.server_args.dp_size - if self.server_args.enable_dp_attention - else 1 - ) - ) - # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to # determine the number of layers. 
@@ -1302,15 +1284,60 @@ def profile_max_num_token(self, total_gpu_memory: int): rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) - if config := self.mambaish_config: - rest_memory -= ( - self.server_args.max_mamba_cache_size - * config.mamba2_cache_params.mamba_cache_per_req - / (1 << 30) - ) + if self.mambaish_config is not None: + rest_memory = self.handle_max_mamba_cache(rest_memory) max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token + def handle_max_mamba_cache(self, total_rest_memory): + config = self.mambaish_config + server_args = self.server_args + assert config is not None + + speculativa_ratio = ( + 0 + if server_args.speculative_num_draft_tokens is None + else server_args.speculative_num_draft_tokens + ) + if ( + server_args.disable_radix_cache + or config.mamba2_cache_params.mamba_cache_per_req == 0 + ): + # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests + if server_args.max_mamba_cache_size is None: + if server_args.max_running_requests is not None: + server_args.max_mamba_cache_size = server_args.max_running_requests + else: + server_args.max_mamba_cache_size = 512 + else: + # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory + # solve the equations: + # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory + # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio + mamba_state_memory_raw = ( + total_rest_memory + * server_args.mamba_full_memory_ratio + / (1 + server_args.mamba_full_memory_ratio) + ) + # calculate the max_mamba_cache_size based on the given total mamba memory + server_args.max_mamba_cache_size = int( + (mamba_state_memory_raw * (1 << 30)) + // config.mamba2_cache_params.mamba_cache_per_req + // (1 + speculativa_ratio) + ) + + if self.hybrid_gdn_config is not None: + server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // ( + server_args.dp_size if server_args.enable_dp_attention else 1 + ) + mamba_state_memory = ( + server_args.max_mamba_cache_size + * config.mamba2_cache_params.mamba_cache_per_req + * (1 + speculativa_ratio) + / (1 << 30) + ) + return total_rest_memory - mamba_state_memory + @property def hybrid_gdn_config(self): config = self.model_config.hf_config @@ -1462,8 +1489,16 @@ def init_memory_pool( ), 4096, ) + if self.mambaish_config is not None: - max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size) + ratio = ( + MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO + if not self.server_args.disable_radix_cache + else 1 + ) + max_num_reqs = min( + max_num_reqs, self.server_args.max_mamba_cache_size // ratio + ) if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): if self.is_draft_worker: @@ -1546,6 +1581,7 @@ def init_memory_pool( elif config := self.mambaish_config: self.req_to_token_pool = HybridReqToTokenPool( size=max_num_reqs, + mamba_size=self.server_args.max_mamba_cache_size, max_context_len=self.model_config.context_len + extra_max_context_len, device=self.device, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3b205ef8f57..286fb3cd711 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -354,6 +354,7 @@ class ServerArgs: # Mamba cache max_mamba_cache_size: Optional[int] = None mamba_ssm_dtype: str = "float32" + mamba_full_memory_ratio: float = 0.2 # Hierarchical cache enable_hierarchical_cache: bool = False @@ -2344,6 +2345,12 
@@ def add_cli_args(parser: argparse.ArgumentParser): choices=["float32", "bfloat16"], help="The data type of the SSM states in mamba cache.", ) + parser.add_argument( + "--mamba-full-memory-ratio", + type=float, + default=ServerArgs.mamba_full_memory_ratio, + help="The ratio of mamba state memory to full kv cache memory.", + ) # Args for multi-item-scoring parser.add_argument( "--multi-item-scoring-delimiter", diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 998e6b0bd0d..2fa2bf76ffc 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -86,6 +86,7 @@ class TestFile: TestFile("test_io_struct.py", 8), TestFile("test_jinja_template_utils.py", 1), TestFile("test_logprobs.py", 55), + TestFile("test_mamba_unittest.py", 4), TestFile("test_metrics.py", 32), TestFile("test_metrics_utils.py", 1), TestFile("test_mla.py", 167), diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py new file mode 100644 index 00000000000..401eb584f96 --- /dev/null +++ b/test/srt/test_mamba_unittest.py @@ -0,0 +1,332 @@ +import inspect +import os +import unittest + +import torch + +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache +from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool +from sglang.srt.mem_cache.radix_cache import RadixKey +from sglang.srt.sampling.sampling_params import SamplingParams + + +class TestMamba(unittest.TestCase): + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + + def test_hybrid_linear_kv_pool(self): + size = 16 + head_num = 2 + head_dim = 256 + num_layers = 48 + global_interval = 4 + dtype = torch.bfloat16 + device = "cuda" + full_attention_layer_ids = [ + i for i in range(global_interval - 1, num_layers, global_interval) + ] + pool = HybridLinearKVPool( + size=size, + dtype=dtype, + page_size=1, + head_num=head_num, + head_dim=head_dim, + full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, + device=device, + ) + assert pool._transfer_full_attention_id(global_interval - 1) == 0 + assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1 + with self.assertRaises(ValueError) as context: + pool._transfer_full_attention_id(1) + self.assertIn( + "layer_id=1 not in full attention layers:", str(context.exception) + ) + + def test_mamba_pool(self): + max_num_reqs = 10 + mamba_cache_size = 20 + max_context_len = 128 + device = "cuda" + global_interval = 4 + num_layers = 48 + full_attention_layer_ids = [ + i for i in range(global_interval - 1, num_layers, global_interval) + ] + mamba_layers = [ + i for i in range(num_layers) if i not in full_attention_layer_ids + ] + shape = Mamba2StateShape.create( + tp_world_size=1, + intermediate_size=4096, + n_groups=16, + num_heads=32, + head_dim=128, + state_size=128, + conv_kernel=4, + ) + os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16" + mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers) + + req_to_token_pool = HybridReqToTokenPool( + size=max_num_reqs, + mamba_size=mamba_cache_size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + cache_params=mamba2_cache_params, + speculative_num_draft_tokens=3, + ) + + assert req_to_token_pool.available_size() == max_num_reqs + assert 
req_to_token_pool.mamba_pool.available_size() == mamba_cache_size + + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=1, + ) + req = Req( + rid=0, + origin_input_text="", + origin_input_ids=[], + sampling_params=sampling_params, + ) + + # alloc req + req_index = req_to_token_pool.alloc(1, [req]) + assert req_to_token_pool.available_size() == max_num_reqs - 1 + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1 + + # free req + req_to_token_pool.free(req_index) + assert req_to_token_pool.available_size() == max_num_reqs + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size + + # alloc req without free mamba cache + req.mamba_pool_idx = None + req_index = req_to_token_pool.alloc(1, [req]) + req_to_token_pool.free(req_index, free_mamba_cache=False) + assert req_to_token_pool.available_size() == max_num_reqs + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1 + + # alloc again + req_index = req_to_token_pool.alloc(1, [req]) + assert req_to_token_pool.available_size() == max_num_reqs - 1 + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1 + + def test_mamba_radix_cache_1(self): + # kv cache + size = 128 + dtype = torch.bfloat16 + head_num = 2 + head_dim = 256 + num_layers = 48 + global_interval = 4 + max_num_reqs = 10 + mamba_cache_size = 20 + max_context_len = 128 + device = "cuda" + full_attention_layer_ids = [ + i for i in range(global_interval - 1, num_layers, global_interval) + ] + + # mamba + mamba_layers = [ + i for i in range(num_layers) if i not in full_attention_layer_ids + ] + os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16" + shape = Mamba2StateShape.create( + tp_world_size=1, + intermediate_size=4096, + n_groups=16, + num_heads=32, + head_dim=128, + state_size=128, + conv_kernel=4, + ) + mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers) + + req_to_token_pool = HybridReqToTokenPool( + size=max_num_reqs, + mamba_size=mamba_cache_size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + cache_params=mamba2_cache_params, + speculative_num_draft_tokens=3, + ) + # setup kv pool + pool = HybridLinearKVPool( + size=size, + dtype=dtype, + page_size=1, + head_num=head_num, + head_dim=head_dim, + full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, + device=device, + ) + + # setup token to kv pool allocator + allocator = TokenToKVPoolAllocator( + size=size, + dtype=dtype, + device=device, + kvcache=pool, + need_sort=False, + ) + # setup radix cache + tree = MambaRadixCache( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, + page_size=1, + disable=False, + ) + + def make_dummy_req(): + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=1, + ) + req = Req( + rid=0, + origin_input_text="", + origin_input_ids=[], + sampling_params=sampling_params, + ) + req_to_token_pool.alloc(1, reqs=[req]) + return req + + mamba_pool = req_to_token_pool.mamba_pool + # test + print( + f"[Start] allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}" + ) + req1 = make_dummy_req() + req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3) + assert len(req1_token_ids) == len(req1_kv_indices) + print( + f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" + ) + prefix_len = tree.insert( + RadixKey(req1_token_ids), req1_kv_indices, req1.mamba_pool_idx.unsqueeze(0) + 
+        )
+        print(
+            f"req1: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
+        )
+        req2 = make_dummy_req()
+        req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
+        assert len(req2_token_ids) == len(req2_kv_indices)
+        print(
+            f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
+        )
+        prefix_len = tree.insert(
+            RadixKey(req2_token_ids), req2_kv_indices, req2.mamba_pool_idx.unsqueeze(0)
+        )
+        print(
+            f"req2: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
+        )
+
+        req3 = make_dummy_req()
+        req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
+        assert len(req3_token_ids) == len(req3_kv_indices)
+        print(
+            f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
+        )
+        prefix_len = tree.insert(
+            RadixKey(req3_token_ids), req3_kv_indices, req3.mamba_pool_idx.unsqueeze(0)
+        )
+        print(
+            f"req3: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
+        )
+        req4 = make_dummy_req()
+        req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
+        assert len(req4_token_ids) == len(req4_kv_indices)
+        print(
+            f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
+        )
+        prefix_len = tree.insert(
+            RadixKey(req4_token_ids), req4_kv_indices, req4.mamba_pool_idx.unsqueeze(0)
+        )
+        print(
+            f"req4: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
+        )
+
+        tree.pretty_print()
+        full_num_tokens = 1
+        print(f"evicting {full_num_tokens} full token")
+        tree.evict(full_num_tokens=full_num_tokens)
+        tree.pretty_print()
+
+        mamba_num = 1
+        print(f"evicting {mamba_num} mamba")
+        tree.evict_mamba(mamba_num=mamba_num)
+        tree.pretty_print()
+
+        # req5: no prefix of [1..5] is reusable after the evictions above
+        req5_token_ids = [1, 2, 3, 4, 5]
+        result = tree.match_prefix(RadixKey(req5_token_ids))
+        kv_indices, last_node = result.device_indices, result.last_device_node
+        print(
+            f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
+        )
+        assert len(kv_indices) == 0
+
+        # req6: the [1, 2, 3, 4, 5, 60, 70] path still matches in full
+        req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
+        result = tree.match_prefix(RadixKey(req6_token_ids))
+        kv_indices, last_node = result.device_indices, result.last_device_node
+        print(
+            f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
+        )
+        assert len(kv_indices) == 7
+        assert len(last_node.key) == 2
+
+        req7_token_ids = [1, 2, 3, 4, 5, 6, 7]
+        result = tree.match_prefix(RadixKey(req7_token_ids))
+        kv_indices, last_node = result.device_indices, result.last_device_node
+        print(
+            f"req7: token_ids: {req7_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
+        )
+        assert len(kv_indices) == 7
+        assert len(last_node.key) == 2
+
+        mamba_num = 1
+        print(f"evicting {mamba_num} mamba")
+        tree.evict_mamba(mamba_num=mamba_num)
+        tree.pretty_print()
+
+        # req8: the [1, 2, 3, 4, 5, 60, 70] path no longer matches after this eviction
+        req8_token_ids = [1, 2, 3, 4, 5, 60, 70]
+        result = tree.match_prefix(RadixKey(req8_token_ids))
+        kv_indices, last_node = result.device_indices, result.last_device_node
+        print(
+            f"req8: token_ids: {req8_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
+        )
+        assert len(kv_indices) == 0
+        assert len(last_node.key) == 0
+
+        # req9: match with cow_mamba; the matched node's conv/temporal states
+        # are copied into req9's own mamba slot
+        req9_token_ids = [1, 2, 3, 4, 5, 6, 7]
+        req9 = make_dummy_req()
+        result = tree.match_prefix(
+            RadixKey(req9_token_ids), **({"req": req9, "cow_mamba": True})
+        )
+        kv_indices, last_node = result.device_indices, result.last_device_node
+        assert req9.mamba_pool_idx is not None
+        assert torch.all(
+            mamba_pool.mamba_cache.conv[:, req9.mamba_pool_idx]
+            == mamba_pool.mamba_cache.conv[:, last_node.mamba_value]
+        )
+        assert torch.all(
+            mamba_pool.mamba_cache.temporal[:, req9.mamba_pool_idx]
+            == mamba_pool.mamba_cache.temporal[:, last_node.mamba_value]
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()