From 4de95c369d7a37d7042463fc5bc583e2b2e2a932 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sat, 27 Sep 2025 15:42:36 +0800 Subject: [PATCH 01/14] init support mamba radix cache --- python/sglang/srt/managers/schedule_batch.py | 14 + python/sglang/srt/managers/schedule_policy.py | 12 + python/sglang/srt/managers/scheduler.py | 72 ++ .../srt/managers/scheduler_metrics_mixin.py | 36 + .../sglang/srt/mem_cache/mamba_radix_cache.py | 990 ++++++++++++++++++ python/sglang/srt/mem_cache/memory_pool.py | 61 +- .../sglang/srt/model_executor/model_runner.py | 6 +- 7 files changed, 1160 insertions(+), 31 deletions(-) create mode 100644 python/sglang/srt/mem_cache/mamba_radix_cache.py diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f46e160cd2e..f92bcc5c83c 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -61,6 +61,7 @@ ) from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache @@ -502,6 +503,7 @@ def __init__( # Memory pool info self.req_pool_idx: Optional[int] = None + self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1) # Check finish self.tokenizer = None @@ -697,6 +699,11 @@ def init_next_round_input( key=RadixKey( token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key ), + **( + {"req": self, "cow_mamba": True} + if isinstance(tree_cache, MambaRadixCache) + else {} + ), ) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) @@ -1013,6 +1020,13 @@ def is_empty(self): def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None): if isinstance(self.req_to_token_pool, HybridReqToTokenPool): + mamba_available_size = self.req_to_token_pool.mamba_pool.available_size() + if mamba_available_size < num_reqs: + if self.tree_cache is not None and isinstance( + self.tree_cache, MambaRadixCache + ): + mamba_num = max(0, num_reqs - mamba_available_size) + self.tree_cache.evict_mamba(mamba_num) req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs) else: req_pool_indices = self.req_to_token_pool.alloc(num_reqs) diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 755ac29c8f1..6a229f60f4d 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -27,6 +27,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.server_args import ServerArgs @@ -353,6 +354,7 @@ def __init__( self.is_hybrid = isinstance( self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator ) + self.is_hybrid_gdn = isinstance(self.tree_cache, MambaRadixCache) self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold @@ -376,6 +378,11 @@ def rem_total_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + 
self.tree_cache.swa_evictable_size(), ) + elif self.is_hybrid_gdn: + available_and_evictable = ( + self.token_to_kv_pool_allocator.available_size() + + self.tree_cache.full_evictable_size() + ) else: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() @@ -393,6 +400,11 @@ def cur_rem_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) + elif self.is_hybrid_gdn: + available_and_evictable = ( + self.token_to_kv_pool_allocator.available_size() + + self.tree_cache.full_evictable_size() + ) else: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index b84608e06a3..109f2a4bcfa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -145,6 +145,7 @@ from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors @@ -433,6 +434,8 @@ def __init__( # Hybrid memory pool self.is_hybrid = self.tp_worker.is_hybrid + self.is_hybrid_gdn = self.tp_worker.worker.model_runner.is_hybrid_gdn + if self.is_hybrid: self.sliding_window_size = self.tp_worker.sliding_window_size self.full_tokens_per_layer, self.swa_tokens_per_layer = ( @@ -729,6 +732,16 @@ def init_memory_pool_and_cache(self): page_size=self.page_size, disable=server_args.disable_radix_cache, ) + elif self.is_hybrid_gdn: + assert ( + self.server_args.disaggregation_mode == "null" + ), "Hybrid GDN mode does not support disaggregation yet" + self.tree_cache = MambaRadixCache( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + page_size=self.page_size, + disable=server_args.disable_radix_cache, + ) elif server_args.enable_lmcache: from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( LMCRadixCache, @@ -1584,6 +1597,25 @@ def check_memory(self): f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n" f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n" ) + elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache): + ( + full_num_used, + mamba_num_used, + _, + _, + full_available_size, + full_evictable_size, + mamba_available_size, + mamba_evictable_size, + ) = self._get_mamba_token_info() + memory_leak = ( + full_num_used != self.tree_cache.full_protected_size() + or mamba_num_used != self.tree_cache.mamba_protected_size() + ) + token_msg = ( + f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n" + f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n" + ) else: _, _, available_size, evictable_size = self._get_token_info() protected_size = self.tree_cache.protected_size() @@ -1634,6 +1666,17 @@ def check_memory(self): ) = self._get_swa_token_info() num_used = max(full_num_used, swa_num_used) token_usage = 
max(full_token_usage, swa_token_usage) + elif self.is_hybrid_gdn: + ( + num_used, + _, + token_usage, + _, + _, + _, + _, + _, + ) = self._get_mamba_token_info() else: num_used, token_usage, _, _ = self._get_token_info() num_running_reqs = len(self.running_batch.reqs) @@ -1671,6 +1714,35 @@ def _get_token_info(self): token_usage = num_used / self.max_total_num_tokens return num_used, token_usage, available_size, evictable_size + def _get_mamba_token_info(self): + is_radix_tree = isinstance(self.tree_cache, MambaRadixCache) + full_available_size = self.token_to_kv_pool_allocator.available_size() + full_evictable_size = ( + self.tree_cache.full_evictable_size() if is_radix_tree else 0 + ) + mamba_available_size = self.req_to_token_pool.mamba_pool.available_size() + mamba_evictable_size = ( + self.tree_cache.mamba_evictable_size() if is_radix_tree else 0 + ) + full_num_used = self.token_to_kv_pool_allocator.size - ( + full_available_size + full_evictable_size + ) + mamba_num_used = self.req_to_token_pool.mamba_pool.size - ( + mamba_available_size + mamba_evictable_size + ) + full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size + mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size + return ( + full_num_used, + mamba_num_used, + full_token_usage, + mamba_usage, + full_available_size, + full_evictable_size, + mamba_available_size, + mamba_evictable_size, + ) + def _get_swa_token_info(self): full_available_size = self.token_to_kv_pool_allocator.full_available_size() full_evictable_size = self.tree_cache.full_evictable_size() diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 66cdc95bbd3..a5a39c4c0ff 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -105,6 +105,23 @@ def log_prefill_stats( f"full token usage: {full_token_usage:.2f}, " f"swa token usage: {swa_token_usage:.2f}, " ) + elif self.is_hybrid_gdn: + ( + full_num_used, + _, + full_token_usage, + mamba_usage, + _, + _, + _, + _, + ) = self._get_mamba_token_info() + num_used = full_num_used + token_usage = full_token_usage + token_msg = ( + f"full token usage: {full_token_usage:.2f}, " + f"mamba usage: {mamba_usage:.2f}, " + ) else: num_used, token_usage, _, _ = self._get_token_info() token_msg = f"token usage: {token_usage:.2f}, " @@ -187,6 +204,25 @@ def log_decode_stats( f"#swa token: {swa_num_used}, " f"swa token usage: {swa_token_usage:.2f}, " ) + elif self.is_hybrid_gdn: + ( + full_num_used, + mamba_used, + full_token_usage, + mamba_usage, + _, + _, + _, + _, + ) = self._get_mamba_token_info() + num_used = full_num_used + token_usage = full_token_usage + token_msg = ( + f"#full token: {full_num_used}, " + f"full token usage: {full_token_usage:.2f}, " + f"mamba num: {mamba_used}, " + f"mamba usage: {mamba_usage:.2f}, " + ) else: num_used, token_usage, _, _ = self._get_token_info() token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py new file mode 100644 index 00000000000..e1ca6d7da4a --- /dev/null +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -0,0 +1,990 @@ +from __future__ import annotations + +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +The radix tree data structure for managing the hybrid (full and Mamba) KV cache. +""" + +import heapq +import time +from collections import defaultdict +from functools import partial +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch + +from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult +from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool +from sglang.srt.mem_cache.radix_cache import ( + RadixKey, + _key_match_page_size1, + _key_match_paged, + get_child_key, +) + +if TYPE_CHECKING: + from sglang.srt.managers.schedule_batch import Req + +import logging + +logger = logging.getLogger(__name__) + + +class TreeNode: + + counter = 0 + + def __init__(self, id: Optional[int] = None): + self.children = defaultdict(TreeNode) + self.parent: TreeNode = None + self.key: RadixKey = None + self.value: Optional[torch.Tensor] = None + self.mamba_value: Optional[torch.Tensor] = None + # invariant: for any node, if mamba_lock_ref is locked, full_lock_ref must be locked; + # if full_lock_ref is locked, mamba_lock_ref doesn't need to be locked. So, + # full_lock_ref is always >= mamba_lock_ref. + # for full_lock, once it is locked, its parent must be locked as well + # for mamba_lock, it only need lock node itself + self.full_lock_ref = 0 + self.mamba_lock_ref = 0 + # last access time is only used for sanity check. LRU is maintained by the lru list. + self.last_access_time = time.monotonic() + + self.hit_count = 0 + # store the host indices of KV cache + self.host_value = None + + # for lru list, invariant: + # 1. prev has greater last_access_time + # 2. 
next has smaller last_access_time + self.prev = None + self.next = None + self.mamba_prev = None + self.mamba_next = None + + self.id = TreeNode.counter if id is None else id + TreeNode.counter += 1 + + @property + def evicted(self): + return self.value is None + + @property + def backuped(self): + return self.host_value is not None + + def __lt__(self, other: "TreeNode"): + return self.last_access_time < other.last_access_time + + +class LRUList: + def __init__(self, mamba: bool = False): + self.mamba = mamba + if self.mamba: + self.prv = "mamba_prev" + self.nxt = "mamba_next" + self.lock_ref = "mamba_lock_ref" + else: + self.prv = "prev" + self.nxt = "next" + self.lock_ref = "full_lock_ref" + # Initialize dummy head and tail nodes + self.head = TreeNode() # Most recently used side + self.tail = TreeNode() # Least recently used side + setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail + setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head + self.cache = {} + + def _add_node(self, node): + """Helper to add node right after head (most recently used)""" + self._add_node_after(self.head, node) + + def _add_node_after(self, old_node, new_node): + """Helper to add node right after old_node""" + setattr(new_node, self.prv, old_node) # new_node.prev = old_node + setattr( + new_node, self.nxt, getattr(old_node, self.nxt) + ) # new_node.next = old_node.next + setattr( + getattr(old_node, self.nxt), self.prv, new_node + ) # old_node.next.prev = new_node + setattr(old_node, self.nxt, new_node) # old_node.next = new_node + + def _remove_node(self, node): + """Helper to remove node from linked list""" + setattr( + getattr(node, self.prv), self.nxt, getattr(node, self.nxt) + ) # node.prev.next = node.next + setattr( + getattr(node, self.nxt), self.prv, getattr(node, self.prv) + ) # node.next.prev = node.prev + + def _get_lru(self) -> Optional[TreeNode]: + """ + Get the least recently used node + """ + if len(self.cache) == 0: + return None + return getattr(self.tail, self.prv) + + def reset_node_mru(self, node): + """ + Move a (existing) node to most recently used position + """ + assert node.id in self.cache, f"Resetting node {node.id=} not in lru list" + assert ( + not self.mamba or node.mamba_value is not None + ), f"Resetting mamba tombstone node in mamba lru list: {node.id=}" + self._remove_node(node) + self._add_node(node) + + def reset_node_and_parents_mru(self, node, root_node): + """ + Move an (existing) node and its parents to most recently used position. Child node is + more recently used than parent node. 
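+        Only the full-token LRU list uses this; the mamba LRU list never calls it
+        (hence the assert below).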
+ """ + assert not self.mamba + prev_node = self.head + while node != root_node: + assert ( + node.id in self.cache + ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru" + self._remove_node(node) + self._add_node_after(prev_node, node) + prev_node = node + node = node.parent + + def insert_mru(self, node): + """ + Insert a (new) node as most recently used + """ + assert ( + not self.mamba or node.mamba_value is not None + ), f"Inserting mamba tombstone node in mamba lru list: {node.id=}" + assert ( + node.id not in self.cache + ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}" + self.cache[node.id] = node + self._add_node(node) + + def remove_node(self, node: TreeNode): + """ + Remove node from lru list + """ + assert node.id in self.cache, f"Removing node {node.id=} not in lru list" + assert ( + not self.mamba or node.mamba_value is not None + ), f"Removing mamba tombstone node from mamba lru list: {node.id=}" + del self.cache[node.id] + self._remove_node(node) + + def get_lru_no_lock(self) -> Optional[TreeNode]: + """ + Get the least recently used node that is not locked + """ + return self.get_prev_no_lock(self.tail, check_id=False) + + def get_leaf_lru_no_lock(self) -> Optional[TreeNode]: + """ + Get the least recently used leaf node that is not locked + """ + return self.get_prev_leaf_no_lock(self.tail, check_id=False) + + def get_prev_no_lock( + self, node: TreeNode, check_id: bool = True + ) -> Optional[TreeNode]: + """ + Get the previous (i.e. more recently used) node that is not locked + """ + if check_id: + assert ( + node.id in self.cache + ), f"Getting prev of node {node.id=} not in lru list" + x = getattr(node, self.prv) # x = node.prev + while getattr(x, self.lock_ref) > 0: + x = getattr(x, self.prv) # x = x.prev + # if x is the head, it means there is no node in the lru list without lock + if x == self.head: + return None + return x + + def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True): + """ + Get the previous (i.e. more recently used) leaf node that is not locked + """ + if check_id: + assert ( + node.id in self.cache + ), f"Getting prev of node {node.id=} not in lru list" + x = getattr(node, self.prv) # x = node.prev + while getattr(x, self.lock_ref) > 0 or len(x.children) > 0: + x = getattr(x, self.prv) # x = x.prev + # if x is the head, it means there is no leaf node in the lru list without lock + if x == self.head: + return None + return x + + def in_list(self, node: Optional[TreeNode]): + """ + Check if the node is in the lru list + """ + if not node: + return False + return node.id in self.cache + + # Note: this is expensive, only use for debug + def sanity_check_evictable_size(self): + """ + Check the evictable size (i.e. the size of the nodes that are not locked) + """ + node = self.get_lru_no_lock() + evictable_size = 0 + while self.in_list(node): + evictable_size += ( + len(node.value) if not self.mamba else len(node.mamba_value) + ) + node = self.get_prev_no_lock(node) + return evictable_size + + # Note: this is expensive, only use for debug or idle check + def sanity_check(self, tree_cache: "MambaRadixCache"): + """ + Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and + checking if the lru list is valid. 
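+        Raises an exception (and logs an error) if the rebuilt LRU order or the
+        evictable sizes disagree with the tree.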
+ """ + try: + if self.mamba: + nodes = tree_cache._collect_nontombstone_nodes() + else: + nodes = tree_cache._collect_all_nodes() + total_nodes = len(nodes) + total_lru_plus_1 = len(self.cache) + 1 + # heapify based on last_access_time + heapq.heapify(nodes) + # the root node is not in the lru list + assert ( + len(nodes) == len(self.cache) + 1 + ), f"len(nodes): {len(nodes)} != len(self.cache) + 1: {len(self.cache) + 1}" + + x_lru = self._get_lru() + while len(nodes): + x = heapq.heappop(nodes) + if x == tree_cache.root_node: + # root node is not in the lru list + continue + assert ( + x == x_lru + ), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}" + assert ( + x_lru.full_lock_ref == 0 + ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}" + assert ( + x_lru.mamba_lock_ref == 0 + ), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}" + x_lru = getattr(x, self.prv) + + if self.mamba: + evictable_size = tree_cache.mamba_evictable_size() + lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size() + else: + evictable_size = tree_cache.full_evictable_size() + lru_list_evictable_size = tree_cache.full_lru_list_evictable_size() + + assert ( + evictable_size == lru_list_evictable_size + ), f"{self.mamba=}, total nodes: {total_nodes}, total lru plus 1: {total_lru_plus_1}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}" + except Exception as e: + msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}" + logger.error(msg) + raise Exception(msg) + + +class MambaRadixCache(BasePrefixCache): + def __init__( + self, + req_to_token_pool: HybridReqToTokenPool, + token_to_kv_pool_allocator: TokenToKVPoolAllocator, + page_size: int, + disable: bool = False, + ): + assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator) + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + + assert page_size == 1, "Only support page_size=1 in mamba radix cache now." + self.page_size = page_size + self.disable = disable + + if self.token_to_kv_pool_allocator: + self.device = self.token_to_kv_pool_allocator.device + else: + self.device = torch.device("cpu") + + self.key_match_fn = _key_match_page_size1 + self.get_child_key_fn = get_child_key + self.reset() + + ##### Public API ##### + + def reset(self) -> None: + self.root_node = TreeNode() + self.root_node.key = [] + self.root_node.value = [] + self.root_node.full_lock_ref = 1 + self.root_node.mamba_lock_ref = 1 + self.full_evictable_size_ = 0 + self.mamba_evictable_size_ = 0 + self.full_protected_size_ = 0 + self.mamba_protected_size_ = 0 + # LRU lists are used to maintain the order of eviction of the nodes in the tree + self.full_lru_list = LRUList(mamba=False) + self.mamba_lru_list = LRUList(mamba=True) + + def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: + """Find the matching prefix from the radix tree. + Args: + key: A RadixKey contains token IDs to find a matching prefix. + Returns: + A tuple of a tensor of matching prefix token IDs and + the last node that contains the prefix values. Note that + this API can modify the internal state of the Radix tree. + The last node create a new child if the prefix is shorter + than the last node's value. 
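+
+        When called with cow_mamba=True and a req keyword argument, the mamba state
+        of the matched node is copied into the request's own slot in the mamba pool
+        (copy-on-write), so the request can update its own state without touching
+        the cached copy. Illustrative call, mirroring how the scheduler invokes it:
+
+            result = tree_cache.match_prefix(
+                RadixKey(token_ids, extra_key), req=req, cow_mamba=True
+            )
+            prefix_kv_indices = result.device_indices
+            last_node = result.last_device_node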
+ """ + cow_mamba: bool = kwargs.get("cow_mamba", False) + req: Req = kwargs.get("req", None) + + if self.disable or len(key) == 0: + return MatchResult( + device_indices=torch.empty( + (0,), + dtype=torch.int64, + device=self.device, + ), + last_device_node=self.root_node, + last_host_node=self.root_node, + ) + + value, last_node = self._match_prefix_helper(key) + + # copy mamba state to req local space if cow is true + if cow_mamba and last_node.mamba_value is not None: + assert req.req_pool_idx is None # req_pool_idx is uninitialed + dst_index = self.req_to_token_pool.mamba_pool.alloc(1) + # try to alloc again, protect last_node from eviction + if dst_index is None: + self.inc_lock_ref(last_node) + self.evict_mamba(1) + dst_index = self.req_to_token_pool.mamba_pool.alloc(1) + self.dec_lock_ref(last_node) + assert dst_index is not None, "Can not alloc mamba cache" + src_index = last_node.mamba_value + self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) + req.mamba_pool_idx = dst_index[0] + if value: + value = torch.cat(value) + else: + value = torch.empty((0,), dtype=torch.int64, device=self.device) + + return MatchResult( + device_indices=value, + last_device_node=last_node, + last_host_node=last_node, + ) + + def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]: + if self.disable: + return 0 + + if value is None: + value = torch.tensor([x for x in key.token_ids], dtype=torch.int64) + return self._insert_helper(self.root_node, key, value, mamba_value) + + def cache_finished_req(self, req: Req) -> None: + """Cache request when it finishes.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, + : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0), + ] + self.token_to_kv_pool_allocator.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + return + + token_ids = (req.origin_input_ids + req.output_ids)[:-1] + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.clone() + + # Radix Cache takes one ref in memory pool + # insert the token_ids and kv_indices into the radix tree + # Note: the insert function already frees the overlapped kv_indices + mamba_value = self.req_to_token_pool.get_mamba_indices( + req.req_pool_idx + ).unsqueeze(-1) + # radix tree mamba value is forked from req space + mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value) + # if alloc mamba cache failed, do evict and alloc again + if mamba_value_forked is None: + self.evict_mamba(1) + mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from( + mamba_value + ) + assert mamba_value_forked is not None, "Can not alloc mamba cache" + + new_prefix_len, mamba_exist = self.insert( + RadixKey(token_ids[:page_aligned_len], req.extra_key), + page_aligned_kv_indices, + mamba_value_forked, + ) + + self.token_to_kv_pool_allocator.free( + kv_indices[len(req.prefix_indices) : new_prefix_len] + ) + + self.req_to_token_pool.free(req.req_pool_idx) + # there is a mamba cache in radix cache, release it + if mamba_exist: + self.req_to_token_pool.mamba_pool.free(mamba_value_forked) + self.dec_lock_ref(req.last_node) + + def cache_unfinished_req(self, req: Req, chunked=False) -> None: + """Cache request when it is unfinished.""" + if self.disable: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.fill_ids) + ] + # `req.prefix_indices` will be used in 
`PrefillAdder::add_chunked_req` later + req.prefix_indices = kv_indices + return + + token_ids = req.fill_ids + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(token_ids) + ] + + # Radix Cache takes one ref in memory pool + # Note: the insert function already frees the overlapped kv_indices + mamba_value = self.req_to_token_pool.get_mamba_indices( + req.req_pool_idx + ).unsqueeze(-1) + # radix tree mamba value is forked from req space + mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value) + + # if alloc mamba cache failed, do evict and alloc again + if mamba_value_forked is None: + self.evict_mamba(1) + mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from( + mamba_value + ) + assert mamba_value_forked is not None, "Can not alloc mamba cache" + new_prefix_len, mamba_exist = self.insert( + RadixKey(token_ids, req.extra_key), + kv_indices, + mamba_value_forked, + ) + self.token_to_kv_pool_allocator.free( + kv_indices[len(req.prefix_indices) : new_prefix_len] + ) + # there is a mamba cache in radix cache, release it + if mamba_exist: + self.req_to_token_pool.mamba_pool.free(mamba_value_forked) + + # The prefix indices could be updated, reuse it + new_indices, new_last_node, _, _ = self.match_prefix( + RadixKey(token_ids, req.extra_key) + ) + + if not mamba_exist: + assert new_last_node.mamba_value == mamba_value_forked + assert len(req.prefix_indices) <= len( + new_indices + ), f"{req.prefix_indices=}, {new_indices=}" + assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}" + + self.req_to_token_pool.write( + (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), + new_indices[len(req.prefix_indices) :], + ) + + self.dec_lock_ref(req.last_node) + self.inc_lock_ref(new_last_node) + + # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later + req.prefix_indices = new_indices + req.last_node = new_last_node + + def pretty_print(self) -> None: + self._print_helper(self.root_node, 0) + total_size, total_mamba_size = self._total_size_helper() + print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}") + + def total_size(self) -> Tuple[int, int]: + return self._total_size_helper() + + def evict_mamba(self, mamba_num: int) -> None: + if self.disable or mamba_num <= 0: + return + # get the least recently used node that is not locked, doesn't have to be a leaf + x = self.mamba_lru_list.get_lru_no_lock() + mamba_num_evicted = 0 + # evict lru leaf nodes until mamba_num_tokens is reached + while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)): + assert x.mamba_value is not None, f"node has no mamba value, {x.id=}" + assert x != self.root_node, f"root node is not evictable, {x.id=}" + assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}" + + if len(x.children) > 0: + # 1. an internal node, free mamba tokens. + self.req_to_token_pool.mamba_pool.free(x.mamba_value) + mamba_num_evicted += len(x.mamba_value) + + # 2. get the next node, update the lru lists + x_next = self.mamba_lru_list.get_prev_no_lock(x) + self.mamba_lru_list.remove_node(x) + + # 3. tombstone the node + self._tombstone_internal_node(x) + else: + assert ( + x.full_lock_ref == 0 + ), f"leaf node with full lock must also have mamba lock, {x.id=} {x.full_lock_ref=}" + # 1. a leaf node, free full and mamba tokens + self.token_to_kv_pool_allocator.free(x.value) + self.req_to_token_pool.mamba_pool.free(x.mamba_value) + mamba_num_evicted += len(x.mamba_value) + + # 2. 
get the next node, update the lru lists + x_next = self.mamba_lru_list.get_prev_no_lock(x) + self.full_lru_list.remove_node(x) + self.mamba_lru_list.remove_node(x) + + # 3. delete the leaf node + self._delete_leaf(x) + + # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone + self._iteratively_delete_tombstone_leaf(x) + + x = x_next + + def evict(self, full_num_tokens: int) -> None: + if self.disable or full_num_tokens <= 0: + return + + full_num_evicted = 0 + # get the least recently used leaf node that is not locked + x = self.full_lru_list.get_leaf_lru_no_lock() + + while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x): + assert ( + x != self.root_node + ), f"root node should not exist in full lru list, {x.id=}" + assert x.full_lock_ref == 0, f"node is in use, {x.id=}" + + # 1. free node kv indices, evict full and mamba tokens + self.token_to_kv_pool_allocator.free(x.value) + full_num_evicted += len(x.value) + self.req_to_token_pool.mamba_pool.free(x.mamba_value) + mamba_num_evicted += len(x.mamba_value) + + # 2. get the next leaf, update the lru lists + x_next = self.full_lru_list.get_prev_leaf_no_lock(x) + self.full_lru_list.remove_node(x) + self.mamba_lru_list.remove_node(x) + + # 3. delete the leaf node + self._delete_leaf(x) + + # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone + x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x) + full_num_evicted += leaf_full_num_evicted + + # 5. if parent has no more children, it is a leaf. It is possible that this node is lru, so + # we need to get the first leaf node in the lru list + if len(x.parent.children) == 0: + x_next = self.full_lru_list.get_leaf_lru_no_lock() + + x = x_next + + def inc_lock_ref(self, node: TreeNode) -> Optional[int]: + """ + Increment the lock reference count for the node. + It locks the full_lock_ref for nodes between the [last node, root), exclusive. + It locks the mamba_lock_ref for current node if its mamba_value exists. + """ + if self.disable: + return None + + # protect mamba value in current node if it exists + if node.mamba_value is not None: + if node.mamba_lock_ref == 0: + self.mamba_evictable_size_ -= len(node.mamba_value) + self.mamba_protected_size_ += len(node.mamba_value) + node.mamba_lock_ref += 1 + + while node != self.root_node: + # lock full from node to root + assert ( + node.full_lock_ref >= 0 + ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}" + if node.full_lock_ref == 0: + self.full_evictable_size_ -= len(node.value) + self.full_protected_size_ += len(node.value) + node.full_lock_ref += 1 + node = node.parent + return None + + def dec_lock_ref(self, node: TreeNode): + """ + Decrement the lock reference count for the node. + It unlocks the full_lock_ref for nodes between the [last node, root), exclusive. + It unlocks the mamba_lock_ref for current node if its mamba_value exists. 
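+        Must be paired with an earlier inc_lock_ref on the same node; for example,
+        cache_unfinished_req releases the lock on the old req.last_node and then
+        locks the newly matched node.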
+ """ + if self.disable: + return + + if node.mamba_value is not None: + assert ( + node.mamba_lock_ref > 0 + ), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}" + if node.mamba_lock_ref == 1: + self.mamba_evictable_size_ += len(node.mamba_value) + self.mamba_protected_size_ -= len(node.mamba_value) + node.mamba_lock_ref -= 1 + + while node != self.root_node: + assert ( + node.full_lock_ref > 0 + ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}" + if node.full_lock_ref == 1: + self.full_evictable_size_ += len(node.value) + self.full_protected_size_ -= len(node.value) + node.full_lock_ref -= 1 + node = node.parent + + def sanity_check(self): + self.full_lru_list.sanity_check(self) + self.mamba_lru_list.sanity_check(self) + + def evictable_size(self) -> Tuple[int, int]: + # Note: use full_evictable_size() and mamba_evictable_size() instead. + raise NotImplementedError + + def full_evictable_size(self) -> int: + return self.full_evictable_size_ + + def mamba_evictable_size(self) -> int: + return self.mamba_evictable_size_ + + # Note: this is expensive, only use for debug + def full_lru_list_evictable_size(self) -> int: + return self.full_lru_list.sanity_check_evictable_size() + + # Note: this is expensive, only use for debug + def mamba_lru_list_evictable_size(self) -> int: + return self.mamba_lru_list.sanity_check_evictable_size() + + def protected_size(self) -> Tuple[int, int]: + # Note: use full_protected_size() and mamba_protected_size() instead. + raise NotImplementedError + + def full_protected_size(self) -> int: + # protected size refers to the size of the full cache that is locked + return self.full_protected_size_ + + def mamba_protected_size(self) -> int: + # protected size refers to the size of the mamba cache that is locked + return self.mamba_protected_size_ + + def all_values_flatten(self) -> torch.Tensor: + values = [] + + def _dfs_helper(node: TreeNode): + for _, child in node.children.items(): + values.append(child.value) + _dfs_helper(child) + + _dfs_helper(self.root_node) + return torch.cat(values) + + ##### Internal Helper Functions ##### + + def _match_prefix_helper( + self, key: RadixKey + ) -> Tuple[List[torch.Tensor], TreeNode]: + """ + Mamba prefix matching helper. It factors in the sliding window size such that + the matched node is guaranteed to either 1. connected to root without mamba tombstone, + or 2. the number of matching tokens from the matched node to the last mamba tombstone + node is greater than or equal to the sliding window size. 
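+        The returned prefix is truncated at the deepest matched node that still
+        holds a mamba state (best_last_node); deeper matched nodes that are
+        tombstoned are not included.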
+ """ + node = self.root_node + child_key = self.get_child_key_fn(key) + + value = [] + best_value_len = 0 + best_last_node = node + while len(key) > 0 and child_key in node.children.keys(): + child = node.children[child_key] + # update best_value_len and best_last_node if needed + if node.mamba_value is not None: + best_value_len = len(value) + best_last_node = node + + prefix_len = self.key_match_fn(child.key, key) + if prefix_len < len(child.key): + new_node = self._split_node(child.key, child, prefix_len) + value.append(new_node.value) + node = new_node + break + else: + value.append(child.value) + node = child + key = key[prefix_len:] + + if len(key): + child_key = self.get_child_key_fn(key) + # handle best_value_len and best_last_node, for the case that last node is fully matched + if node.mamba_value is not None: + best_value_len = len(value) + best_last_node = node + + # update time for matched nodes, and make nodes closer to root to be least recently used + # this allows mamba to evict nodes closer to root first + self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) + if best_last_node.mamba_value is not None: + self.mamba_lru_list.reset_node_mru(best_last_node) + + # This last_access_time is for sanity check, can be deleted after validation in production + cur_time = time.monotonic() + while node: + node.last_access_time = cur_time + cur_time -= 0.0001 + node = node.parent + + return value[:best_value_len], best_last_node + + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: + # new_node -> child + new_node = TreeNode() + new_node.children = {self.get_child_key_fn(key[split_len:]): child} + new_node.parent = child.parent + new_node.mamba_value = None # mamba cache can not be splitted + new_node.full_lock_ref = child.full_lock_ref + new_node.mamba_lock_ref = 0 + new_node.key = child.key[:split_len] + new_node.value = child.value[:split_len] + + # child time should be later than parent's time for mamba tombstone + child.last_access_time = time.monotonic() + + self.full_lru_list.remove_node(child) + child.parent = new_node + child.key = child.key[split_len:] + child.value = child.value[split_len:] + new_node.parent.children[self.get_child_key_fn(key)] = new_node + + # insert the new node and child into the lru lists, insert + # parent first so that parent is after child in the lru list + self.full_lru_list.insert_mru(new_node) + self.full_lru_list.insert_mru(child) + return new_node + + def _insert_helper( + self, + node: TreeNode, + key: RadixKey, + value, + mamba_value, + ) -> Tuple[int, bool]: + # Update the last access time from root to leaf, so that + # mamba will tombstone the node closer to root first + node.last_access_time = time.monotonic() + if node != self.root_node: + self.full_lru_list.reset_node_mru(node) + if len(key) == 0: + return 0, True + + child_key = self.get_child_key_fn(key) + + total_prefix_length = 0 + while len(key) > 0 and child_key in node.children.keys(): + node = node.children[child_key] + node.last_access_time = time.monotonic() + self.full_lru_list.reset_node_mru(node) + prefix_len = self.key_match_fn(node.key, key) + total_prefix_length += prefix_len + key = key[prefix_len:] + value = value[prefix_len:] + + if prefix_len < len(node.key): + new_node = self._split_node(node.key, node, prefix_len) + node = new_node + + if len(key): + child_key = self.get_child_key_fn(key) + + mamba_value_exist = False + if len(key): + new_node = TreeNode() + new_node.parent = node + new_node.key = key + 
new_node.value = value + new_node.mamba_value = mamba_value + self.full_lru_list.insert_mru(new_node) + self.full_evictable_size_ += len(value) + if mamba_value is not None: + self.mamba_evictable_size_ += len(mamba_value) + self.mamba_lru_list.insert_mru(new_node) + node.children[child_key] = new_node + elif node.mamba_value is None: # add for mamba tombstone + node.mamba_value = mamba_value + if mamba_value is not None: + self.mamba_evictable_size_ += len(mamba_value) + self.mamba_lru_list.insert_mru(node) + else: + mamba_value_exist = True + + return total_prefix_length, mamba_value_exist + + def _iteratively_delete_tombstone_leaf( + self, node: TreeNode + ) -> Tuple[TreeNode, int]: + full_num_evicted = 0 + while node.parent.mamba_value is None and len(node.parent.children) == 0: + # root node is not evictable + if node.parent == self.root_node: + break + # if locked, means node is in use, skip + if node.parent.full_lock_ref > 0: + break + assert ( + node.parent.mamba_lock_ref == 0 + ), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}" + # delete tombstone node evicts full tokens + self.token_to_kv_pool_allocator.free(node.parent.value) + full_num_evicted += len(node.parent.value) + self.full_lru_list.remove_node(node.parent) + self._delete_tombstone_leaf(node.parent) + node = node.parent + + return node, full_num_evicted + + def _delete_leaf(self, node: TreeNode) -> None: + assert ( + node.mamba_value is not None + ), f"Invariant violated: leaf node is a tombstone, {node.id=}" + assert len(node.children) == 0, f"leaf node has children, {node.id=}" + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.full_evictable_size_ -= len(node.key) + self.mamba_evictable_size_ -= len(node.mamba_value) + + def _tombstone_internal_node(self, node: TreeNode) -> None: + assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}" + self.mamba_evictable_size_ -= ( + len(node.mamba_value) if node.mamba_value is not None else 0 + ) + node.mamba_value = None + + def _delete_tombstone_leaf(self, node: TreeNode) -> None: + assert ( + node.mamba_value is None + ), f"Deleting a unexpected non-tombstone leaf node, {node.id=}" + assert len(node.children) == 0, f"leaf node has children, {node.id=}" + for k, v in node.parent.children.items(): + if v == node: + break + del node.parent.children[k] + self.full_evictable_size_ -= len(node.key) + + def _collect_leaves(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if len(cur_node.children) == 0: + ret_list.append(cur_node) + else: + stack.extend(cur_node.children.values()) + + return ret_list + + def _collect_nontombstone_nodes(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + + while stack: + cur_node = stack.pop() + if cur_node.mamba_value is not None: + ret_list.append(cur_node) + stack.extend(cur_node.children.values()) + + return ret_list + + def _collect_all_nodes(self) -> List[TreeNode]: + ret_list = [] + stack = [self.root_node] + while stack: + cur_node = stack.pop() + ret_list.append(cur_node) + stack.extend(cur_node.children.values()) + return ret_list + + def _print_helper(self, node: TreeNode, indent: int) -> None: + """Prints the radix tree in a human-readable format.""" + stack = [(node, indent)] + while stack: + current_node, current_indent = stack.pop() + print( + " " * current_indent, + f"[{current_node.id}]", + 
len(current_node.key), + f"fr={current_node.full_lock_ref}", + f"mr={current_node.mamba_lock_ref}", + f"fll={self.full_lru_list.in_list(current_node)}", + f"mll={self.mamba_lru_list.in_list(current_node)}", + f"mv={current_node.mamba_value}", + ) + for key, child in current_node.children.items(): + stack.append((child, current_indent + 2)) + + assert key == self.get_child_key_fn( + child.key + ), f"{key=}, {self.get_child_key_fn(child.key)=}" + + def _total_size_helper(self) -> Tuple[int, int]: + total_size = 0 + total_mamba_size = 0 + stack = [self.root_node] + while stack: + current_node = stack.pop() + total_size += len(current_node.value) + if current_node.mamba_value is not None: + total_mamba_size += len(current_node.mamba_value) + for child in current_node.children.values(): + if child.evicted: + continue + stack.append(child) + return total_size, total_mamba_size diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 5b0f8a7141c..93eec3a6bcd 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -177,7 +177,8 @@ def __init__( f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " ) self.size = size - self.free_slots = list(range(size)) + self.device = device + self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) self.mem_usage = self.get_mamba_size() / GB def get_mamba_params_all_layers(self): @@ -192,7 +193,7 @@ def get_mamba_size(self): def available_size(self): return len(self.free_slots) - def alloc(self, need_size: int) -> Optional[List[int]]: + def alloc(self, need_size: int) -> Optional[torch.Tensor]: if need_size > len(self.free_slots): return None @@ -201,15 +202,26 @@ def alloc(self, need_size: int) -> Optional[List[int]]: return select_index - def free(self, free_index: Union[int, List[int]]): - if isinstance(free_index, (int,)): - self.free_slots.append(free_index) - else: - self.free_slots.extend(free_index) + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + self.free_slots = torch.cat((self.free_slots, free_index)) self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0 def clear(self): - self.free_slots = list(range(self.size)) + self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) + + def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): + self.mamba_cache[0][:, dst_index] = self.mamba_cache[0][:, src_index] + self.mamba_cache[1][:, dst_index] = self.mamba_cache[1][:, src_index] + return + + def fork_from(self, src_index: torch.Tensor) -> Optional[int]: + dst_index = self.alloc(1) + if dst_index == None: + return None + self.copy_from(src_index, dst_index) + return dst_index class HybridReqToTokenPool(ReqToTokenPool): @@ -218,6 +230,7 @@ class HybridReqToTokenPool(ReqToTokenPool): def __init__( self, size: int, + mamba_size: int, max_context_len: int, device: str, enable_memory_saver: bool, @@ -236,7 +249,7 @@ def __init__( ) self.mamba_pool = MambaPool( - size, + mamba_size, conv_dtype, ssm_dtype, len(mamba_layers), @@ -252,9 +265,6 @@ def __init__( size, dtype=torch.int32, device=self.device ) - self.rid_to_mamba_index_mapping: Dict[str, int] = {} - self.mamba_index_to_rid_mapping: Dict[int, str] = {} - # For chunk prefill req, we do not need to allocate mamba cache, # We could use allocated mamba cache instead. 
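+    # A request keeps a single mamba slot: if req.mamba_pool_idx is already set
+    # (e.g. assigned by MambaRadixCache.match_prefix via copy-on-write), it is
+    # reused; otherwise a fresh slot is allocated from self.mamba_pool.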
def alloc( @@ -266,14 +276,14 @@ def alloc( mamba_index = [] for req in reqs: - rid = req.rid - if rid in self.rid_to_mamba_index_mapping: - mid = self.rid_to_mamba_index_mapping[rid] - elif (mid := self.mamba_pool.alloc(1)) is not None: - mid = mid[0] - self.rid_to_mamba_index_mapping[rid] = mid - self.mamba_index_to_rid_mapping[mid] = rid - mamba_index.append(mid) + mid = None + if req.mamba_pool_idx is not None: # for radix cache + mid = req.mamba_pool_idx + else: + mid = self.mamba_pool.alloc(1) + req.mamba_pool_idx = mid + if mid is not None: + mamba_index.append(mid) assert len(select_index) == len( mamba_index ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size." @@ -294,17 +304,12 @@ def get_mamba_params_all_layers(self): # For chunk prefill, we can not free mamba cache, we need use it in the future def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True): + if isinstance(free_index, (int,)): + free_index = [free_index] super().free(free_index) if free_mamba_cache: mamba_index = self.req_index_to_mamba_index_mapping[free_index] - mamba_index_list = mamba_index.tolist() - if isinstance(mamba_index_list, int): - mamba_index_list = [mamba_index_list] - self.mamba_pool.free(mamba_index_list) - for mid in mamba_index_list: - rid = self.mamba_index_to_rid_mapping[mid] - self.mamba_index_to_rid_mapping.pop(mid) - self.rid_to_mamba_index_mapping.pop(rid) + self.mamba_pool.free(mamba_index) def clear(self): super().clear() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d053e2bb81b..f8cc6795e43 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -345,8 +345,6 @@ def initialize(self, min_per_gpu_memory: float): self.is_hybrid = self.model_config.is_hybrid = True if self.is_hybrid_gdn: - logger.warning("Hybrid GDN model detected, disable radix cache") - self.server_args.disable_radix_cache = True self.server_args.attention_backend = "hybrid_linear_attn" if self.server_args.max_mamba_cache_size is None: if self.server_args.max_running_requests is not None: @@ -1423,7 +1421,8 @@ def init_memory_pool( 4096, ) if self.is_hybrid_gdn: - max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size) + # for mamba cache radix, it need be divided by 4 (magic number now). 
(yizhang2077) + max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 4) if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): if self.is_draft_worker: @@ -1513,6 +1512,7 @@ def init_memory_pool( ) = config.hybrid_gdn_params self.req_to_token_pool = HybridReqToTokenPool( size=max_num_reqs, + mamba_size=self.server_args.max_mamba_cache_size, max_context_len=self.model_config.context_len + extra_max_context_len, device=self.device, From 0710061c54f5027436ad71d527acf553b448c9f5 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Tue, 30 Sep 2025 01:00:06 +0800 Subject: [PATCH 02/14] tinyfix for kvindices in cache_unfinished_req --- .../sglang/srt/mem_cache/mamba_radix_cache.py | 73 +++++++++---------- 1 file changed, 33 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index e1ca6d7da4a..c4c342ed552 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -480,6 +480,9 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(token_ids) ] + page_aligned_len = len(kv_indices) + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + page_aligned_token_ids = token_ids[:page_aligned_len] # Radix Cache takes one ref in memory pool # Note: the insert function already frees the overlapped kv_indices @@ -497,8 +500,8 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: ) assert mamba_value_forked is not None, "Can not alloc mamba cache" new_prefix_len, mamba_exist = self.insert( - RadixKey(token_ids, req.extra_key), - kv_indices, + RadixKey(page_aligned_token_ids, req.extra_key), + page_aligned_kv_indices, mamba_value_forked, ) self.token_to_kv_pool_allocator.free( @@ -540,6 +543,29 @@ def pretty_print(self) -> None: def total_size(self) -> Tuple[int, int]: return self._total_size_helper() + def _evict_leaf_node(self, x: TreeNode) -> Tuple[int, int, TreeNode, TreeNode]: + assert ( + x.full_lock_ref == 0 + ), f"leaf node with full lock must also have mamba lock, {x.id=} {x.full_lock_ref=}" + # 1. a leaf node, free full tokens and mamba + self.token_to_kv_pool_allocator.free(x.value) + full_num_evicted = len(x.value) + self.req_to_token_pool.mamba_pool.free(x.mamba_value) + mamba_num_evicted = len(x.mamba_value) + + # 2. get the next node, update the lru lists + x_next = self.mamba_lru_list.get_prev_no_lock(x) + self.full_lru_list.remove_node(x) + self.mamba_lru_list.remove_node(x) + + # 3. delete the leaf node + self._delete_leaf(x) + + # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone + x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x) + full_num_evicted += leaf_full_num_evicted + return full_num_evicted, mamba_num_evicted, x, x_next + def evict_mamba(self, mamba_num: int) -> None: if self.disable or mamba_num <= 0: return @@ -564,24 +590,8 @@ def evict_mamba(self, mamba_num: int) -> None: # 3. tombstone the node self._tombstone_internal_node(x) else: - assert ( - x.full_lock_ref == 0 - ), f"leaf node with full lock must also have mamba lock, {x.id=} {x.full_lock_ref=}" - # 1. a leaf node, free full and mamba tokens - self.token_to_kv_pool_allocator.free(x.value) - self.req_to_token_pool.mamba_pool.free(x.mamba_value) - mamba_num_evicted += len(x.mamba_value) - - # 2. 
get the next node, update the lru lists - x_next = self.mamba_lru_list.get_prev_no_lock(x) - self.full_lru_list.remove_node(x) - self.mamba_lru_list.remove_node(x) - - # 3. delete the leaf node - self._delete_leaf(x) - - # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone - self._iteratively_delete_tombstone_leaf(x) + _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x) + mamba_num_evicted += mamba_evicted_delta x = x_next @@ -597,27 +607,10 @@ def evict(self, full_num_tokens: int) -> None: assert ( x != self.root_node ), f"root node should not exist in full lru list, {x.id=}" - assert x.full_lock_ref == 0, f"node is in use, {x.id=}" - - # 1. free node kv indices, evict full and mamba tokens - self.token_to_kv_pool_allocator.free(x.value) - full_num_evicted += len(x.value) - self.req_to_token_pool.mamba_pool.free(x.mamba_value) - mamba_num_evicted += len(x.mamba_value) - - # 2. get the next leaf, update the lru lists - x_next = self.full_lru_list.get_prev_leaf_no_lock(x) - self.full_lru_list.remove_node(x) - self.mamba_lru_list.remove_node(x) - - # 3. delete the leaf node - self._delete_leaf(x) - - # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone - x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x) - full_num_evicted += leaf_full_num_evicted + full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x) + full_num_evicted += full_num_evicted_delta - # 5. if parent has no more children, it is a leaf. It is possible that this node is lru, so + # if parent has no more children, it is a leaf. It is possible that this node is lru, so # we need to get the first leaf node in the lru list if len(x.parent.children) == 0: x_next = self.full_lru_list.get_leaf_lru_no_lock() From a141effda11048a4bd98bfda6024811097191ccd Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Fri, 3 Oct 2025 17:26:12 +0000 Subject: [PATCH 03/14] bugfix mem leak --- .../sglang/srt/mem_cache/mamba_radix_cache.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index c4c342ed552..e4610e4562f 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -385,17 +385,24 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # copy mamba state to req local space if cow is true if cow_mamba and last_node.mamba_value is not None: assert req.req_pool_idx is None # req_pool_idx is uninitialed - dst_index = self.req_to_token_pool.mamba_pool.alloc(1) - # try to alloc again, protect last_node from eviction - if dst_index is None: - self.inc_lock_ref(last_node) - self.evict_mamba(1) + + if req.mamba_pool_idx is None: # for reqs without mamba cache dst_index = self.req_to_token_pool.mamba_pool.alloc(1) - self.dec_lock_ref(last_node) - assert dst_index is not None, "Can not alloc mamba cache" - src_index = last_node.mamba_value - self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) - req.mamba_pool_idx = dst_index[0] + # try to alloc again, protect last_node from eviction + if dst_index is None: + self.inc_lock_ref(last_node) + self.evict_mamba(1) + dst_index = self.req_to_token_pool.mamba_pool.alloc(1) + self.dec_lock_ref(last_node) + assert dst_index is not None, "Can not alloc mamba cache" + src_index = last_node.mamba_value + self.req_to_token_pool.mamba_pool.copy_from(src_index, 
dst_index) + req.mamba_pool_idx = dst_index[0] + else: + src_index = last_node.mamba_value + dst_index = req.mamba_pool_idx.unsqueeze(0) + self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index) + if value: value = torch.cat(value) else: From 39cca32127467b06d4005e006780f55830e741b1 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sat, 4 Oct 2025 09:39:01 +0000 Subject: [PATCH 04/14] add ut, fix tiny bug, optimize cache_finished_req --- python/sglang/srt/managers/schedule_policy.py | 6 +- .../sglang/srt/mem_cache/mamba_radix_cache.py | 29 +- python/sglang/srt/mem_cache/memory_pool.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 4 +- test/srt/test_mamba_unittest.py | 324 ++++++++++++++++++ 5 files changed, 340 insertions(+), 25 deletions(-) create mode 100644 test/srt/test_mamba_unittest.py diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 6a229f60f4d..26e4eace8d1 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -354,7 +354,7 @@ def __init__( self.is_hybrid = isinstance( self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator ) - self.is_hybrid_gdn = isinstance(self.tree_cache, MambaRadixCache) + self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache) self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold @@ -378,7 +378,7 @@ def rem_total_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) - elif self.is_hybrid_gdn: + elif self.is_hybrid_gdn_cache: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() + self.tree_cache.full_evictable_size() @@ -400,7 +400,7 @@ def cur_rem_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) - elif self.is_hybrid_gdn: + elif self.is_hybrid_gdn_cache: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() + self.tree_cache.full_evictable_size() diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index e4610e4562f..97152dd9324 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -386,7 +386,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: if cow_mamba and last_node.mamba_value is not None: assert req.req_pool_idx is None # req_pool_idx is uninitialed - if req.mamba_pool_idx is None: # for reqs without mamba cache + # for reqs without mamba cache + if req.mamba_pool_idx is None: dst_index = self.req_to_token_pool.mamba_pool.alloc(1) # try to alloc again, protect last_node from eviction if dst_index is None: @@ -444,33 +445,23 @@ def cache_finished_req(self, req: Req) -> None: # Radix Cache takes one ref in memory pool # insert the token_ids and kv_indices into the radix tree # Note: the insert function already frees the overlapped kv_indices - mamba_value = self.req_to_token_pool.get_mamba_indices( - req.req_pool_idx - ).unsqueeze(-1) - # radix tree mamba value is forked from req space - mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value) - # if alloc mamba cache failed, do evict and alloc again - if mamba_value_forked is None: - self.evict_mamba(1) - mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from( - mamba_value - ) - assert mamba_value_forked is not None, "Can not alloc mamba cache" + mamba_value = 
( + self.req_to_token_pool.get_mamba_indices(req.req_pool_idx) + .unsqueeze(-1) + .clone() + ) new_prefix_len, mamba_exist = self.insert( RadixKey(token_ids[:page_aligned_len], req.extra_key), page_aligned_kv_indices, - mamba_value_forked, + mamba_value, ) self.token_to_kv_pool_allocator.free( kv_indices[len(req.prefix_indices) : new_prefix_len] ) - self.req_to_token_pool.free(req.req_pool_idx) - # there is a mamba cache in radix cache, release it - if mamba_exist: - self.req_to_token_pool.mamba_pool.free(mamba_value_forked) + self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist) self.dec_lock_ref(req.last_node) def cache_unfinished_req(self, req: Req, chunked=False) -> None: @@ -787,7 +778,7 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod new_node = TreeNode() new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.parent = child.parent - new_node.mamba_value = None # mamba cache can not be splitted + new_node.mamba_value = None # mamba cache can not be split new_node.full_lock_ref = child.full_lock_ref new_node.mamba_lock_ref = 0 new_node.key = child.key[:split_len] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 93eec3a6bcd..c8ab99a928d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -280,7 +280,7 @@ def alloc( if req.mamba_pool_idx is not None: # for radix cache mid = req.mamba_pool_idx else: - mid = self.mamba_pool.alloc(1) + mid = self.mamba_pool.alloc(1)[0] req.mamba_pool_idx = mid if mid is not None: mamba_index.append(mid) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f8cc6795e43..34e132125a7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1421,8 +1421,8 @@ def init_memory_pool( 4096, ) if self.is_hybrid_gdn: - # for mamba cache radix, it need be divided by 4 (magic number now). (yizhang2077) - max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 4) + # for mamba cache radix, it need be divided by 3 (magic number now). 
(yizhang2077) + max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3) if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): if self.is_draft_worker: diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py new file mode 100644 index 00000000000..3eeb05fdac0 --- /dev/null +++ b/test/srt/test_mamba_unittest.py @@ -0,0 +1,324 @@ +import inspect +import unittest + +import torch + +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator +from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache +from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool +from sglang.srt.mem_cache.radix_cache import RadixKey +from sglang.srt.sampling.sampling_params import SamplingParams + + +class TestMamba(unittest.TestCase): + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + + def test_hybrid_linear_kv_pool(self): + size = 16 + head_num = 2 + head_dim = 256 + num_layers = 48 + global_interval = 4 + dtype = torch.bfloat16 + device = "cuda" + full_attention_layer_ids = [ + i for i in range(global_interval - 1, num_layers, global_interval) + ] + pool = HybridLinearKVPool( + size=size, + dtype=dtype, + page_size=1, + head_num=head_num, + head_dim=head_dim, + full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, + device=device, + ) + assert pool._transfer_full_attention_id(global_interval - 1) == 0 + assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1 + with self.assertRaises(ValueError) as context: + pool._transfer_full_attention_id(1) + self.assertIn( + "layer_id=1 not in full attention layers:", str(context.exception) + ) + + def test_mamba_pool(self): + max_num_reqs = 10 + mamba_cache_size = 20 + max_context_len = 128 + device = "cuda" + conv_dtype = torch.bfloat16 + ssm_dtype = torch.bfloat16 + global_interval = 4 + num_layers = 48 + full_attention_layer_ids = [ + i for i in range(global_interval - 1, num_layers, global_interval) + ] + mamba_layers = [ + i for i in range(num_layers) if i not in full_attention_layer_ids + ] + conv_state_shape = (8192, 4) + mamba_state_shape = (32, 128, 128) + + req_to_token_pool = HybridReqToTokenPool( + size=max_num_reqs, + mamba_size=mamba_cache_size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + conv_state_shape=conv_state_shape, + temporal_state_shape=mamba_state_shape, + conv_dtype=conv_dtype, + ssm_dtype=ssm_dtype, + mamba_layers=mamba_layers, + speculative_num_draft_tokens=3, + ) + + assert req_to_token_pool.available_size() == max_num_reqs + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size + + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=1, + ) + req = Req( + rid=0, + origin_input_text="", + origin_input_ids=[], + sampling_params=sampling_params, + ) + + # alloc req + req_index = req_to_token_pool.alloc(1, [req]) + assert req_to_token_pool.available_size() == max_num_reqs - 1 + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1 + + # free req + req_to_token_pool.free(req_index) + assert req_to_token_pool.available_size() == max_num_reqs + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size + + # alloc req without free mamba cache + req.mamba_pool_idx = None + req_index = req_to_token_pool.alloc(1, [req]) + req_to_token_pool.free(req_index, free_mamba_cache=False) + assert 
req_to_token_pool.available_size() == max_num_reqs + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1 + + # alloc again + req_index = req_to_token_pool.alloc(1, [req]) + assert req_to_token_pool.available_size() == max_num_reqs - 1 + assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1 + + def test_mamba_radix_cache_1(self): + # kv cache + size = 128 + dtype = torch.bfloat16 + head_num = 2 + head_dim = 256 + num_layers = 48 + global_interval = 4 + max_num_reqs = 10 + mamba_cache_size = 20 + max_context_len = 128 + device = "cuda" + full_attention_layer_ids = [ + i for i in range(global_interval - 1, num_layers, global_interval) + ] + + # mamba + conv_dtype = torch.bfloat16 + ssm_dtype = torch.bfloat16 + mamba_layers = [ + i for i in range(num_layers) if i not in full_attention_layer_ids + ] + conv_state_shape = (8192, 4) + mamba_state_shape = (32, 128, 128) + # setup req to token pool + req_to_token_pool = HybridReqToTokenPool( + size=max_num_reqs, + mamba_size=mamba_cache_size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + conv_state_shape=conv_state_shape, + temporal_state_shape=mamba_state_shape, + conv_dtype=conv_dtype, + ssm_dtype=ssm_dtype, + mamba_layers=mamba_layers, + speculative_num_draft_tokens=3, + ) + # setup kv pool + pool = HybridLinearKVPool( + size=size, + dtype=dtype, + page_size=1, + head_num=head_num, + head_dim=head_dim, + full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, + device=device, + ) + + # setup token to kv pool allocator + allocator = TokenToKVPoolAllocator( + size=size, + dtype=dtype, + device=device, + kvcache=pool, + need_sort=False, + ) + # setup radix cache + tree = MambaRadixCache( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, + page_size=1, + disable=False, + ) + + def make_dummy_req(): + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=1, + ) + req = Req( + rid=0, + origin_input_text="", + origin_input_ids=[], + sampling_params=sampling_params, + ) + req_to_token_pool.alloc(1, reqs=[req]) + return req + + mamba_pool = req_to_token_pool.mamba_pool + # test + print( + f"[Start] allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}" + ) + req1 = make_dummy_req() + req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3) + assert len(req1_token_ids) == len(req1_kv_indices) + print( + f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" + ) + prefix_len = tree.insert( + RadixKey(req1_token_ids), req1_kv_indices, req1.mamba_pool_idx.unsqueeze(0) + ) + print( + f"req1: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}" + ) + req2 = make_dummy_req() + req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7) + assert len(req2_token_ids) == len(req2_kv_indices) + print( + f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}" + ) + prefix_len = tree.insert( + RadixKey(req2_token_ids), req2_kv_indices, req2.mamba_pool_idx.unsqueeze(0) + ) + print( + f"req2: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}" + ) + + req3 = make_dummy_req() + req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3) + assert len(req3_token_ids) == len(req3_kv_indices) + 
print( + f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}" + ) + prefix_len = tree.insert( + RadixKey(req3_token_ids), req3_kv_indices, req3.mamba_pool_idx.unsqueeze(0) + ) + print( + f"req3: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}" + ) + req4 = make_dummy_req() + req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7) + assert len(req4_token_ids) == len(req4_kv_indices) + print( + f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}" + ) + prefix_len = tree.insert( + RadixKey(req4_token_ids), req4_kv_indices, req4.mamba_pool_idx.unsqueeze(0) + ) + print( + f"req4: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}" + ) + + tree.pretty_print() + full_num_tokens = 1 + print(f"evicting {full_num_tokens} full token") + tree.evict(full_num_tokens=full_num_tokens) + tree.pretty_print() + + mamba_num = 1 + print(f"evicting {mamba_num} mamba") + tree.evict_mamba(mamba_num=mamba_num) + tree.pretty_print() + + req5_token_ids = [1, 2, 3, 4, 5] + result = tree.match_prefix(RadixKey(req5_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node + print( + f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + assert len(kv_indices) == 0 + + req6_token_ids = [1, 2, 3, 4, 5, 60, 70] + result = tree.match_prefix(RadixKey(req6_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node + print( + f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + assert len(kv_indices) == 7 + assert len(last_node.key) == 2 + + req7_token_ids = [1, 2, 3, 4, 5, 6, 7] + result = tree.match_prefix(RadixKey(req7_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node + print( + f"req7: token_ids: {req7_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + assert len(kv_indices) == 7 + assert len(last_node.key) == 2 + + mamba_num = 1 + print(f"evicting {mamba_num} mamba") + tree.evict_mamba(mamba_num=mamba_num) + tree.pretty_print() + + req8_token_ids = [1, 2, 3, 4, 5, 60, 70] + result = tree.match_prefix(RadixKey(req8_token_ids)) + kv_indices, last_node = result.device_indices, result.last_device_node + print( + f"req8: token_ids: {req8_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" + ) + assert len(kv_indices) == 0 + assert len(last_node.key) == 0 + + req9_token_ids = [1, 2, 3, 4, 5, 6, 7] + req9 = make_dummy_req() + result = tree.match_prefix( + RadixKey(req9_token_ids), **({"req": req9, "cow_mamba": True}) + ) + kv_indices, last_node = result.device_indices, result.last_device_node + assert req9.mamba_pool_idx is not None + assert torch.all( + mamba_pool.mamba_cache[0][:, req9.mamba_pool_idx] + == mamba_pool.mamba_cache[0][:, last_node.mamba_value] + ) + assert torch.all( + mamba_pool.mamba_cache[1][:, req9.mamba_pool_idx] + == mamba_pool.mamba_cache[1][:, last_node.mamba_value] + ) + + +if __name__ == "__main__": + unittest.main() From ef796ddd7ed51a1ec266b9b2562bf5a7d6349bf2 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sat, 4 Oct 2025 12:29:00 +0000 Subject: [PATCH 05/14] tiny Co-authored-by: hanming-lu Co-authored-by: hzh0425 Co-authored-by: thalahors --- 
test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 684faf7a875..fab5b67d63d 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -83,6 +83,7 @@ class TestFile: TestFile("test_io_struct.py", 8), TestFile("test_jinja_template_utils.py", 1), TestFile("test_logprobs.py", 55), + TestFile("test_mamba_unittest.py", 4), TestFile("test_metrics.py", 32), TestFile("test_metrics_utils.py", 1), TestFile("test_mla.py", 167), From 716a8be92c2b49cc8d02bf05e4a2218ac89998e0 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sun, 5 Oct 2025 13:04:42 +0000 Subject: [PATCH 06/14] fix metrics --- python/sglang/srt/managers/scheduler_metrics_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 53077afbe0f..edbd3f95c02 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -117,7 +117,7 @@ def log_prefill_stats( ) = self._get_mamba_token_info() num_used = full_num_used token_usage = full_token_usage - token_msg = ( + token_usage_msg = ( f"full token usage: {full_token_usage:.2f}, " f"mamba usage: {mamba_usage:.2f}, " ) @@ -233,7 +233,7 @@ def log_decode_stats( ) = self._get_mamba_token_info() num_used = full_num_used token_usage = full_token_usage - token_msg = ( + token_usage_msg = ( f"#full token: {full_num_used}, " f"full token usage: {full_token_usage:.2f}, " f"mamba num: {mamba_used}, " From 047ab9a17718acb63d9985cb4f0a4d6dab40dc6b Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sun, 5 Oct 2025 16:48:17 +0000 Subject: [PATCH 07/14] bugfix --- python/sglang/srt/mem_cache/mamba_radix_cache.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 97152dd9324..2455cc27770 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -541,7 +541,9 @@ def pretty_print(self) -> None: def total_size(self) -> Tuple[int, int]: return self._total_size_helper() - def _evict_leaf_node(self, x: TreeNode) -> Tuple[int, int, TreeNode, TreeNode]: + def _evict_leaf_node( + self, x: TreeNode, is_evict_mamba: bool + ) -> Tuple[int, int, TreeNode, TreeNode]: assert ( x.full_lock_ref == 0 ), f"leaf node with full lock must also have mamba lock, {x.id=} {x.full_lock_ref=}" @@ -552,7 +554,10 @@ def _evict_leaf_node(self, x: TreeNode) -> Tuple[int, int, TreeNode, TreeNode]: mamba_num_evicted = len(x.mamba_value) # 2. get the next node, update the lru lists - x_next = self.mamba_lru_list.get_prev_no_lock(x) + if is_evict_mamba: + x_next = self.mamba_lru_list.get_prev_no_lock(x) + else: + x_next = self.full_lru_list.get_prev_leaf_no_lock(x) self.full_lru_list.remove_node(x) self.mamba_lru_list.remove_node(x) @@ -588,7 +593,7 @@ def evict_mamba(self, mamba_num: int) -> None: # 3. 
tombstone the node self._tombstone_internal_node(x) else: - _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x) + _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True) mamba_num_evicted += mamba_evicted_delta x = x_next @@ -605,7 +610,7 @@ def evict(self, full_num_tokens: int) -> None: assert ( x != self.root_node ), f"root node should not exist in full lru list, {x.id=}" - full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x) + full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False) full_num_evicted += full_num_evicted_delta # if parent has no more children, it is a leaf. It is possible that this node is lru, so From 67d4e34913ae1992219d41f08e7eee826385e098 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sun, 5 Oct 2025 17:17:46 +0000 Subject: [PATCH 08/14] tiny --- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/mem_cache/mamba_radix_cache.py | 2 +- python/sglang/srt/mem_cache/memory_pool.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b4c611a0f07..80a621855c3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -820,6 +820,7 @@ def reset_for_retract(self): self.extend_logprob_start_len = 0 self.is_chunked = 0 self.req_pool_idx = None + self.mamba_pool_idx = None self.already_computed = 0 def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 2455cc27770..7ca3167c1ef 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -515,7 +515,7 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: ) if not mamba_exist: - assert new_last_node.mamba_value == mamba_value_forked + assert torch.equal(new_last_node.mamba_value, mamba_value_forked) assert len(req.prefix_indices) <= len( new_indices ), f"{req.prefix_indices=}, {new_indices=}" diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index c8ab99a928d..fe129eae7ca 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -216,7 +216,7 @@ def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): self.mamba_cache[1][:, dst_index] = self.mamba_cache[1][:, src_index] return - def fork_from(self, src_index: torch.Tensor) -> Optional[int]: + def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]: dst_index = self.alloc(1) if dst_index == None: return None From bb5e87604e574886f1830734a58372fff07e80ec Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Wed, 8 Oct 2025 15:25:46 +0000 Subject: [PATCH 09/14] add more assertion, add MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO constant --- .../sglang/srt/mem_cache/mamba_radix_cache.py | 30 +++++++++---------- .../sglang/srt/model_executor/model_runner.py | 11 +++++-- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 7ca3167c1ef..b656a85b92a 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -482,8 +482,6 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: page_aligned_kv_indices = 
kv_indices.to(dtype=torch.int64, copy=True) page_aligned_token_ids = token_ids[:page_aligned_len] - # Radix Cache takes one ref in memory pool - # Note: the insert function already frees the overlapped kv_indices mamba_value = self.req_to_token_pool.get_mamba_indices( req.req_pool_idx ).unsqueeze(-1) @@ -511,11 +509,9 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: # The prefix indices could be updated, reuse it new_indices, new_last_node, _, _ = self.match_prefix( - RadixKey(token_ids, req.extra_key) + RadixKey(page_aligned_token_ids, req.extra_key) ) - if not mamba_exist: - assert torch.equal(new_last_node.mamba_value, mamba_value_forked) assert len(req.prefix_indices) <= len( new_indices ), f"{req.prefix_indices=}, {new_indices=}" @@ -545,8 +541,10 @@ def _evict_leaf_node( self, x: TreeNode, is_evict_mamba: bool ) -> Tuple[int, int, TreeNode, TreeNode]: assert ( - x.full_lock_ref == 0 - ), f"leaf node with full lock must also have mamba lock, {x.id=} {x.full_lock_ref=}" + x.full_lock_ref == 0 and x.mamba_lock_ref == 0 + ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}" + + assert x.mamba_value is not None, f"leaf node mamba value is not None, {x.id=}" # 1. a leaf node, free full tokens and mamba self.token_to_kv_pool_allocator.free(x.value) full_num_evicted = len(x.value) @@ -578,6 +576,9 @@ def evict_mamba(self, mamba_num: int) -> None: # evict lru leaf nodes until mamba_num_tokens is reached while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)): assert x.mamba_value is not None, f"node has no mamba value, {x.id=}" + assert ( + len(x.mamba_value) == 1 + ), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}" assert x != self.root_node, f"root node is not evictable, {x.id=}" assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}" @@ -813,6 +814,7 @@ def _insert_helper( ) -> Tuple[int, bool]: # Update the last access time from root to leaf, so that # mamba will tombstone the node closer to root first + assert mamba_value is not None, "Mamba value should not be None here." 
node.last_access_time = time.monotonic() if node != self.root_node: self.full_lru_list.reset_node_mru(node) @@ -847,15 +849,13 @@ def _insert_helper( new_node.mamba_value = mamba_value self.full_lru_list.insert_mru(new_node) self.full_evictable_size_ += len(value) - if mamba_value is not None: - self.mamba_evictable_size_ += len(mamba_value) - self.mamba_lru_list.insert_mru(new_node) + self.mamba_evictable_size_ += len(mamba_value) + self.mamba_lru_list.insert_mru(new_node) node.children[child_key] = new_node elif node.mamba_value is None: # add for mamba tombstone node.mamba_value = mamba_value - if mamba_value is not None: - self.mamba_evictable_size_ += len(mamba_value) - self.mamba_lru_list.insert_mru(node) + self.mamba_evictable_size_ += len(mamba_value) + self.mamba_lru_list.insert_mru(node) else: mamba_value_exist = True @@ -898,9 +898,7 @@ def _delete_leaf(self, node: TreeNode) -> None: def _tombstone_internal_node(self, node: TreeNode) -> None: assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}" - self.mamba_evictable_size_ -= ( - len(node.mamba_value) if node.mamba_value is not None else 0 - ) + self.mamba_evictable_size_ -= len(node.mamba_value) node.mamba_value = None def _delete_tombstone_leaf(self, node: TreeNode) -> None: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b2d6458899f..5aca7637727 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -175,6 +175,9 @@ def add_mla_attention_backend(backend_name): # Detect stragger ranks in model loading UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 +# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077) +MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3 + logger = logging.getLogger(__name__) @@ -351,9 +354,14 @@ def initialize(self, min_per_gpu_memory: float): if self.server_args.max_running_requests is not None: self.server_args.max_mamba_cache_size = ( self.server_args.max_running_requests + * MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO ) else: self.server_args.max_mamba_cache_size = 512 + self.server_args.max_running_requests = ( + self.server_args.max_mamba_cache_size + // MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO + ) self.server_args.max_mamba_cache_size = ( self.server_args.max_mamba_cache_size // ( @@ -1427,9 +1435,6 @@ def init_memory_pool( ), 4096, ) - if self.is_hybrid_gdn: - # for mamba cache radix, it need be divided by 3 (magic number now). 
(yizhang2077) - max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3) if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): if self.is_draft_worker: From 13758f6e713429af6efe9d50dbc9b50bcb14c3ec Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Wed, 8 Oct 2025 16:08:15 +0000 Subject: [PATCH 10/14] tiny --- python/sglang/srt/mem_cache/mamba_radix_cache.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index b656a85b92a..9e7e9b0f82f 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -512,6 +512,9 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: RadixKey(page_aligned_token_ids, req.extra_key) ) + if not mamba_exist: + assert torch.equal(new_last_node.mamba_value, mamba_value_forked) + assert len(req.prefix_indices) <= len( new_indices ), f"{req.prefix_indices=}, {new_indices=}" From e01ed7fa9234af84c0578c230e0a38aa0807a999 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Wed, 8 Oct 2025 17:59:30 +0000 Subject: [PATCH 11/14] resolve conflicts --- python/sglang/srt/managers/scheduler.py | 4 +++- python/sglang/srt/mem_cache/memory_pool.py | 6 ++++-- python/sglang/srt/model_executor/model_runner.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 000a9d3b358..9d07c78c0b0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -491,7 +491,9 @@ def __init__( # Hybrid memory pool self.is_hybrid = self.tp_worker.is_hybrid - self.is_hybrid_gdn = self.tp_worker.worker.model_runner.is_hybrid_gdn + self.is_hybrid_gdn = ( + self.tp_worker.worker.model_runner.hybrid_gdn_config is not None + ) if self.is_hybrid: self.sliding_window_size = self.tp_worker.sliding_window_size diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b82862c7359..870b9c01219 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -238,8 +238,10 @@ def clear(self): self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor): - self.mamba_cache[0][:, dst_index] = self.mamba_cache[0][:, src_index] - self.mamba_cache[1][:, dst_index] = self.mamba_cache[1][:, src_index] + self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index] + self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[ + :, src_index + ] return def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6df49ae34d9..4ccc202f76b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -371,7 +371,7 @@ def initialize(self, min_per_gpu_memory: float): self.server_args.max_mamba_cache_size // MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO ) - + if self.hybrid_gdn_config is not None: self.server_args.max_mamba_cache_size = ( self.server_args.max_mamba_cache_size From 867c6545904b1f8159f5fe908b91ef2719595c3d Mon Sep 17 00:00:00 2001 From: Hanming Lu <69857889+hanming-lu@users.noreply.github.com> Date: Wed, 8 Oct 2025 23:51:48 -0700 Subject: 
[PATCH 12/14] [GDN] Mamba radix cache ratio support (#11347) --- python/sglang/srt/mem_cache/memory_pool.py | 2 + .../sglang/srt/model_executor/model_runner.py | 81 ++++++++++++------- python/sglang/srt/server_args.py | 7 ++ 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 870b9c01219..ba11089976b 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -190,6 +190,7 @@ def __init__( ) logger.info( f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {size}, " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " @@ -199,6 +200,7 @@ def __init__( self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) logger.info( f"Mamba Cache is allocated. " + f"max_mamba_cache_size: {size}, " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4ccc202f76b..a33cb76094a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -358,30 +358,6 @@ def initialize(self, min_per_gpu_memory: float): if architectures and not any("Llama4" in arch for arch in architectures): self.is_hybrid = self.model_config.is_hybrid = True - if config := self.mambaish_config: - if self.server_args.max_mamba_cache_size is None: - if self.server_args.max_running_requests is not None: - self.server_args.max_mamba_cache_size = ( - self.server_args.max_running_requests - * MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO - ) - else: - self.server_args.max_mamba_cache_size = 512 - self.server_args.max_running_requests = ( - self.server_args.max_mamba_cache_size - // MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO - ) - - if self.hybrid_gdn_config is not None: - self.server_args.max_mamba_cache_size = ( - self.server_args.max_mamba_cache_size - // ( - self.server_args.dp_size - if self.server_args.enable_dp_attention - else 1 - ) - ) - # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to # determine the number of layers. 
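Note on the mamba state pool the cache relies on: match_prefix with cow_mamba copies a cached node's state into the request's own slot, and fork_from/copy_from in memory_pool.py perform that slot-to-slot copy along dim 1 of the conv/temporal tensors (dim 0 is the mamba layer dimension). The snippet below is a minimal, self-contained sketch of that copy-on-write pattern; ToyMambaPool, the tensor shapes, and the free-slot bookkeeping are illustrative stand-ins rather than the real MambaPool class, and only the copy_from/fork_from indexing mirrors the patch.

from typing import NamedTuple, Optional

import torch


class State(NamedTuple):
    conv: torch.Tensor      # [num_mamba_layers, num_slots, conv-state dims]
    temporal: torch.Tensor  # [num_mamba_layers, num_slots, ssm-state dims]


class ToyMambaPool:
    """Illustrative stand-in for the per-request mamba state pool."""

    def __init__(self, num_layers: int, size: int):
        self.mamba_cache = State(
            conv=torch.zeros(num_layers, size, 8),
            temporal=torch.zeros(num_layers, size, 4),
        )
        self.free_slots = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need: int) -> Optional[torch.Tensor]:
        if len(self.free_slots) < need:
            return None
        picked, self.free_slots = self.free_slots[:need], self.free_slots[need:]
        return torch.tensor(picked, dtype=torch.int64)

    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor) -> None:
        # Slot index lives on dim 1; dim 0 is the mamba layer dimension.
        self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[:, src_index]

    def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
        # Copy-on-write: seed a fresh slot with the cached prefix state.
        dst_index = self.alloc(1)
        if dst_index is None:
            return None
        self.copy_from(src_index, dst_index)
        return dst_index


pool = ToyMambaPool(num_layers=2, size=4)
cached_slot = pool.alloc(1)              # slot owned by a radix-tree node
pool.mamba_cache.conv[:, cached_slot] = 1.0
req_slot = pool.fork_from(cached_slot)   # the request gets its own copy of the prefix state
assert req_slot is not None
assert torch.equal(pool.mamba_cache.conv[:, req_slot], pool.mamba_cache.conv[:, cached_slot])

Forking into a fresh slot rather than sharing one keeps the radix-tree copy immutable while the running request keeps mutating its own conv/ssm state.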
@@ -1297,15 +1273,50 @@ def profile_max_num_token(self, total_gpu_memory: int): rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) - if config := self.mambaish_config: - rest_memory -= ( - self.server_args.max_mamba_cache_size - * config.mamba2_cache_params.mamba_cache_per_req - / (1 << 30) - ) + if self.mambaish_config is not None: + rest_memory = self.handle_max_mamba_cache(rest_memory) max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token + def handle_max_mamba_cache(self, total_rest_memory): + config = self.mambaish_config + server_args = self.server_args + assert config is not None + + if server_args.disable_radix_cache: + # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests + if server_args.max_mamba_cache_size is None: + if server_args.max_running_requests is not None: + server_args.max_mamba_cache_size = server_args.max_running_requests + else: + server_args.max_mamba_cache_size = 512 + else: + # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory + # solve the equations: + # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory + # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio + mamba_state_memory_raw = ( + total_rest_memory + * server_args.mamba_full_memory_ratio + / (1 + server_args.mamba_full_memory_ratio) + ) + # calculate the max_mamba_cache_size based on the given total mamba memory + server_args.max_mamba_cache_size = int( + (mamba_state_memory_raw * (1 << 30)) + // config.mamba2_cache_params.mamba_cache_per_req + ) + + if self.hybrid_gdn_config is not None: + server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // ( + server_args.dp_size if server_args.enable_dp_attention else 1 + ) + mamba_state_memory = ( + server_args.max_mamba_cache_size + * config.mamba2_cache_params.mamba_cache_per_req + / (1 << 30) + ) + return total_rest_memory - mamba_state_memory + @property def hybrid_gdn_config(self): config = self.model_config.hf_config @@ -1458,6 +1469,16 @@ def init_memory_pool( 4096, ) + if self.mambaish_config is not None: + ratio = ( + MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO + if not self.server_args.disable_radix_cache + else 1 + ) + max_num_reqs = min( + max_num_reqs, self.server_args.max_mamba_cache_size // ratio + ) + if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): if self.is_draft_worker: self.max_total_num_tokens = self.server_args.draft_runner_cache_size diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 52cbd038f67..9c2f8f23c66 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -353,6 +353,7 @@ class ServerArgs: # Mamba cache max_mamba_cache_size: Optional[int] = None mamba_ssm_dtype: str = "float32" + mamba_full_memory_ratio: float = 0.2 # Hierarchical cache enable_hierarchical_cache: bool = False @@ -2334,6 +2335,12 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=["float32", "bfloat16"], help="The data type of the SSM states in mamba cache.", ) + parser.add_argument( + "--mamba-full-memory-ratio", + type=float, + default=ServerArgs.mamba_full_memory_ratio, + help="The ratio of mamba state memory to full kv cache memory.", + ) # Hierarchical cache parser.add_argument( From 48db1ca5d2383020eb75861314c4d2d0164f48c7 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Fri, 10 Oct 2025 06:02:58 +0000 Subject: [PATCH 13/14] fix ut, 
open sanity check --- python/sglang/srt/managers/scheduler.py | 4 +- .../sglang/srt/mem_cache/mamba_radix_cache.py | 38 +++++++------ test/srt/test_mamba_unittest.py | 54 +++++++++++-------- 3 files changed, 57 insertions(+), 39 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a03fa73e68a..1580319a942 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1779,7 +1779,9 @@ def check_memory(self): self._publish_kv_events() def check_tree_cache(self): - if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache): + if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or ( + self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache) + ): self.tree_cache.sanity_check() def _get_token_info(self): diff --git a/python/sglang/srt/mem_cache/mamba_radix_cache.py b/python/sglang/srt/mem_cache/mamba_radix_cache.py index 9e7e9b0f82f..7467daa5d56 100644 --- a/python/sglang/srt/mem_cache/mamba_radix_cache.py +++ b/python/sglang/srt/mem_cache/mamba_radix_cache.py @@ -158,15 +158,15 @@ def reset_node_and_parents_mru(self, node, root_node): Move an (existing) node and its parents to most recently used position. Child node is more recently used than parent node. """ - assert not self.mamba prev_node = self.head while node != root_node: - assert ( - node.id in self.cache - ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru" - self._remove_node(node) - self._add_node_after(prev_node, node) - prev_node = node + if not self.mamba or node.mamba_value is not None: + assert ( + node.id in self.cache + ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru" + self._remove_node(node) + self._add_node_after(prev_node, node) + prev_node = node node = node.parent def insert_mru(self, node): @@ -273,13 +273,13 @@ def sanity_check(self, tree_cache: "MambaRadixCache"): else: nodes = tree_cache._collect_all_nodes() total_nodes = len(nodes) - total_lru_plus_1 = len(self.cache) + 1 + total_lru = len(self.cache) # heapify based on last_access_time heapq.heapify(nodes) # the root node is not in the lru list - assert ( - len(nodes) == len(self.cache) + 1 - ), f"len(nodes): {len(nodes)} != len(self.cache) + 1: {len(self.cache) + 1}" + assert len(nodes) == ( + total_lru + (0 if self.mamba else 1) + ), f"len(nodes): {len(nodes)}, total_lru: {total_lru}" x_lru = self._get_lru() while len(nodes): @@ -307,7 +307,7 @@ def sanity_check(self, tree_cache: "MambaRadixCache"): assert ( evictable_size == lru_list_evictable_size - ), f"{self.mamba=}, total nodes: {total_nodes}, total lru plus 1: {total_lru_plus_1}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}" + ), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}" except Exception as e: msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}" logger.error(msg) @@ -440,7 +440,7 @@ def cache_finished_req(self, req: Req) -> None: ] page_aligned_len = len(kv_indices) - page_aligned_kv_indices = kv_indices.clone() + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) # Radix Cache takes one ref in memory pool # insert the token_ids and kv_indices into the radix tree @@ -770,8 +770,7 @@ def _match_prefix_helper( # update time for matched nodes, and make nodes closer to root to be least recently used # this allows mamba to 
evict nodes closer to root first self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) - if best_last_node.mamba_value is not None: - self.mamba_lru_list.reset_node_mru(best_last_node) + self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node) # This last_access_time is for sanity check, can be deleted after validation in production cur_time = time.monotonic() @@ -797,6 +796,8 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod child.last_access_time = time.monotonic() self.full_lru_list.remove_node(child) + if child.mamba_value is not None: + self.mamba_lru_list.remove_node(child) child.parent = new_node child.key = child.key[split_len:] child.value = child.value[split_len:] @@ -806,6 +807,8 @@ def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNod # parent first so that parent is after child in the lru list self.full_lru_list.insert_mru(new_node) self.full_lru_list.insert_mru(child) + if child.mamba_value is not None: + self.mamba_lru_list.insert_mru(child) return new_node def _insert_helper( @@ -821,6 +824,8 @@ def _insert_helper( node.last_access_time = time.monotonic() if node != self.root_node: self.full_lru_list.reset_node_mru(node) + if node.mamba_value is not None: + self.mamba_lru_list.reset_node_mru(node) if len(key) == 0: return 0, True @@ -831,6 +836,8 @@ def _insert_helper( node = node.children[child_key] node.last_access_time = time.monotonic() self.full_lru_list.reset_node_mru(node) + if node.mamba_value is not None: + self.mamba_lru_list.reset_node_mru(node) prefix_len = self.key_match_fn(node.key, key) total_prefix_length += prefix_len key = key[prefix_len:] @@ -861,6 +868,7 @@ def _insert_helper( self.mamba_lru_list.insert_mru(node) else: mamba_value_exist = True + self.mamba_lru_list.reset_node_mru(node) return total_prefix_length, mamba_value_exist diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py index 3eeb05fdac0..401eb584f96 100644 --- a/test/srt/test_mamba_unittest.py +++ b/test/srt/test_mamba_unittest.py @@ -1,8 +1,10 @@ import inspect +import os import unittest import torch +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape from sglang.srt.managers.schedule_batch import Req from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache @@ -54,8 +56,6 @@ def test_mamba_pool(self): mamba_cache_size = 20 max_context_len = 128 device = "cuda" - conv_dtype = torch.bfloat16 - ssm_dtype = torch.bfloat16 global_interval = 4 num_layers = 48 full_attention_layer_ids = [ @@ -64,8 +64,17 @@ def test_mamba_pool(self): mamba_layers = [ i for i in range(num_layers) if i not in full_attention_layer_ids ] - conv_state_shape = (8192, 4) - mamba_state_shape = (32, 128, 128) + shape = Mamba2StateShape.create( + tp_world_size=1, + intermediate_size=4096, + n_groups=16, + num_heads=32, + head_dim=128, + state_size=128, + conv_kernel=4, + ) + os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16" + mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers) req_to_token_pool = HybridReqToTokenPool( size=max_num_reqs, @@ -73,11 +82,7 @@ def test_mamba_pool(self): max_context_len=max_context_len, device=device, enable_memory_saver=False, - conv_state_shape=conv_state_shape, - temporal_state_shape=mamba_state_shape, - conv_dtype=conv_dtype, - ssm_dtype=ssm_dtype, - mamba_layers=mamba_layers, + cache_params=mamba2_cache_params, 
speculative_num_draft_tokens=3, ) @@ -134,25 +139,28 @@ def test_mamba_radix_cache_1(self): ] # mamba - conv_dtype = torch.bfloat16 - ssm_dtype = torch.bfloat16 mamba_layers = [ i for i in range(num_layers) if i not in full_attention_layer_ids ] - conv_state_shape = (8192, 4) - mamba_state_shape = (32, 128, 128) - # setup req to token pool + os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16" + shape = Mamba2StateShape.create( + tp_world_size=1, + intermediate_size=4096, + n_groups=16, + num_heads=32, + head_dim=128, + state_size=128, + conv_kernel=4, + ) + mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers) + req_to_token_pool = HybridReqToTokenPool( size=max_num_reqs, mamba_size=mamba_cache_size, max_context_len=max_context_len, device=device, enable_memory_saver=False, - conv_state_shape=conv_state_shape, - temporal_state_shape=mamba_state_shape, - conv_dtype=conv_dtype, - ssm_dtype=ssm_dtype, - mamba_layers=mamba_layers, + cache_params=mamba2_cache_params, speculative_num_draft_tokens=3, ) # setup kv pool @@ -311,12 +319,12 @@ def make_dummy_req(): kv_indices, last_node = result.device_indices, result.last_device_node assert req9.mamba_pool_idx is not None assert torch.all( - mamba_pool.mamba_cache[0][:, req9.mamba_pool_idx] - == mamba_pool.mamba_cache[0][:, last_node.mamba_value] + mamba_pool.mamba_cache.conv[:, req9.mamba_pool_idx] + == mamba_pool.mamba_cache.conv[:, last_node.mamba_value] ) assert torch.all( - mamba_pool.mamba_cache[1][:, req9.mamba_pool_idx] - == mamba_pool.mamba_cache[1][:, last_node.mamba_value] + mamba_pool.mamba_cache.temporal[:, req9.mamba_pool_idx] + == mamba_pool.mamba_cache.temporal[:, last_node.mamba_value] ) From 8022cdfd2c1514142778eaf50ef930c57ab20877 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Fri, 10 Oct 2025 06:44:58 +0000 Subject: [PATCH 14/14] fix mtp initial bug --- python/sglang/srt/model_executor/model_runner.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cd887364da6..63a252e5941 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1294,7 +1294,15 @@ def handle_max_mamba_cache(self, total_rest_memory): server_args = self.server_args assert config is not None - if server_args.disable_radix_cache: + speculativa_ratio = ( + 0 + if server_args.speculative_num_draft_tokens is None + else server_args.speculative_num_draft_tokens + ) + if ( + server_args.disable_radix_cache + or config.mamba2_cache_params.mamba_cache_per_req == 0 + ): # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests if server_args.max_mamba_cache_size is None: if server_args.max_running_requests is not None: @@ -1315,6 +1323,7 @@ def handle_max_mamba_cache(self, total_rest_memory): server_args.max_mamba_cache_size = int( (mamba_state_memory_raw * (1 << 30)) // config.mamba2_cache_params.mamba_cache_per_req + // (1 + speculativa_ratio) ) if self.hybrid_gdn_config is not None: @@ -1324,6 +1333,7 @@ def handle_max_mamba_cache(self, total_rest_memory): mamba_state_memory = ( server_args.max_mamba_cache_size * config.mamba2_cache_params.mamba_cache_per_req + * (1 + speculativa_ratio) / (1 << 30) ) return total_rest_memory - mamba_state_memory
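Taken together, patches 12 and 14 size the mamba state pool from the memory left after the static allocations: the rest memory is split between mamba state and full-attention KV cache according to --mamba-full-memory-ratio, the per-request state cost (mamba2_cache_params.mamba_cache_per_req) is scaled by the number of speculative draft tokens, and the slot count is divided across DP ranks when DP attention is enabled. The following is a self-contained sketch of that arithmetic only, ignoring the disable-radix-cache branch; the function name, the GB-based units, and the example numbers are invented for illustration, and only the formulas mirror handle_max_mamba_cache.

from typing import Optional, Tuple


def split_rest_memory_gb(
    total_rest_memory_gb: float,
    mamba_full_memory_ratio: float,
    mamba_cache_per_req_bytes: int,
    speculative_num_draft_tokens: Optional[int] = None,
    dp_size: int = 1,
) -> Tuple[int, float]:
    """Return (max_mamba_cache_size, remaining_gb_for_full_kv_cache).

    Solves the two equations from the patch:
      mamba_state_memory + full_kv_cache_memory == total_rest_memory
      mamba_state_memory / full_kv_cache_memory == mamba_full_memory_ratio
    """
    spec_ratio = speculative_num_draft_tokens or 0
    mamba_state_memory_raw = (
        total_rest_memory_gb * mamba_full_memory_ratio / (1 + mamba_full_memory_ratio)
    )
    max_mamba_cache_size = int(
        (mamba_state_memory_raw * (1 << 30))
        // mamba_cache_per_req_bytes
        // (1 + spec_ratio)
    )
    # Per-rank slot count when DP attention shards the mamba pool (simplified here).
    max_mamba_cache_size //= dp_size
    mamba_state_memory_gb = (
        max_mamba_cache_size * mamba_cache_per_req_bytes * (1 + spec_ratio) / (1 << 30)
    )
    return max_mamba_cache_size, total_rest_memory_gb - mamba_state_memory_gb


# Example: 40 GB of rest memory, default ratio 0.2, ~8 MiB of mamba state per request,
# 3 speculative draft tokens -> roughly 213 slots and ~33 GB left for the full KV cache.
print(split_rest_memory_gb(40.0, 0.2, 8 << 20, speculative_num_draft_tokens=3))

Folding the speculative ratio into both directions of the conversion, as patch 14 does, keeps the pool from being over-committed when draft tokens also consume per-request mamba state.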