Merged
25 commits
4de95c3
init support mamba radix cache
yizhang2077 Sep 27, 2025
0710061
tinyfix for kvindices in cache_unfinished_req
yizhang2077 Sep 29, 2025
a141eff
bugfix mem leak
yizhang2077 Oct 3, 2025
39cca32
add ut, fix tiny bug, optimize cache_finished_req
yizhang2077 Oct 4, 2025
ef796dd
tiny
yizhang2077 Oct 4, 2025
b753f80
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 4, 2025
716a8be
fix metrics
yizhang2077 Oct 5, 2025
047ab9a
bugfix
yizhang2077 Oct 5, 2025
67d4e34
tiny
yizhang2077 Oct 5, 2025
bb5e876
add more assertion, add MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO c…
yizhang2077 Oct 8, 2025
e7eaa5f
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 8, 2025
13758f6
tiny
yizhang2077 Oct 8, 2025
66afa78
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 8, 2025
e01ed7f
resolve conflicts
yizhang2077 Oct 8, 2025
ae4703b
Merge branch 'main' into support_mamba_radix_cache
zhyncs Oct 8, 2025
867c654
[GDN] Mamba radix cache ratio support (#11347)
hanming-lu Oct 9, 2025
0251832
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 9, 2025
eba82ba
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 9, 2025
48db1ca
fix ut, open sanity check
yizhang2077 Oct 10, 2025
8022cdf
fix mtp initial bug
yizhang2077 Oct 10, 2025
d2d1b35
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 10, 2025
f577e47
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 10, 2025
7f83210
Merge branch 'main' into support_mamba_radix_cache
yizhang2077 Oct 12, 2025
c4273af
disable other mambarish models radix cache
yizhang2077 Oct 12, 2025
f22a940
Merge branch 'main' into support_mamba_radix_cache
zhyncs Oct 12, 2025
33 changes: 31 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
@@ -65,7 +65,8 @@
alloc_for_extend,
alloc_token_slots,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -481,6 +482,7 @@ def __init__(

# Memory pool info
self.req_pool_idx: Optional[int] = None
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)

# Check finish
self.tokenizer = None
@@ -683,7 +685,12 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
**(
{"req": self, "cow_mamba": True}
if isinstance(tree_cache, MambaRadixCache)
else {}
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
@@ -833,6 +840,7 @@ def reset_for_retract(self):
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
self.mamba_pool_idx = None
self.already_computed = 0

def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
@@ -1027,6 +1035,27 @@ def batch_size(self):
def is_empty(self):
return len(self.reqs) == 0

def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
if mamba_available_size < num_reqs:
if self.tree_cache is not None and isinstance(
self.tree_cache, MambaRadixCache
):
mamba_num = max(0, num_reqs - mamba_available_size)
self.tree_cache.evict_mamba(mamba_num)
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices

def allocate_for_eagle_v2(self):
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
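The hunks above thread the mamba state pool through request bookkeeping: `match_prefix` gains a copy-on-write path (`cow_mamba=True`) when the cache is a `MambaRadixCache`, and `alloc_req_slots` evicts just enough cached mamba states to cover a shortfall before allocating. Below is a minimal standalone sketch of that evict-then-allocate flow; `SlotPool` and the driver code are hypothetical stand-ins for illustration, not SGLang classes.

```python
from typing import List, Optional


class SlotPool:
    """Toy fixed-size slot pool standing in for the mamba state pool."""

    def __init__(self, size: int):
        self.free = list(range(size))

    def available_size(self) -> int:
        return len(self.free)

    def alloc(self, n: int) -> Optional[List[int]]:
        if len(self.free) < n:
            return None
        out, self.free = self.free[:n], self.free[n:]
        return out

    def release(self, slots: List[int]) -> None:
        self.free.extend(slots)


def alloc_req_slots(pool: SlotPool, evictable: List[int], num_reqs: int) -> List[int]:
    # Mirrors the hunk: evict only the shortfall, then allocate.
    shortfall = max(0, num_reqs - pool.available_size())
    for _ in range(min(shortfall, len(evictable))):
        pool.release([evictable.pop()])  # stand-in for tree_cache.evict_mamba(...)
    slots = pool.alloc(num_reqs)
    if slots is None:
        raise RuntimeError("alloc_req_slots runs out of memory")
    return slots


pool = SlotPool(size=4)
cached = pool.alloc(3)  # slots still held by the radix tree for finished requests
print(alloc_req_slots(pool, evictable=cached, num_reqs=3))  # evicts 2, allocates 3
```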
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/schedule_policy.py
@@ -27,6 +27,7 @@
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.server_args import ServerArgs

@@ -357,6 +358,7 @@ def __init__(
self.is_hybrid = isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache)

self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
@@ -380,6 +382,11 @@ def rem_total_tokens(self):
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
@@ -397,6 +404,11 @@ def cur_rem_tokens(self):
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
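For the hybrid GDN cache, `rem_total_tokens` and `cur_rem_tokens` reduce to the same budget: tokens the allocator can hand out now plus tokens reclaimable by evicting unreferenced radix-tree nodes. A hedged sketch of that arithmetic, with invented numbers rather than real pool sizes:

```python
def remaining_token_budget(allocator_available: int, full_evictable: int) -> int:
    # available_and_evictable = free allocator tokens + evictable tree tokens
    return allocator_available + full_evictable


assert remaining_token_budget(allocator_available=1024, full_evictable=512) == 1536
```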
78 changes: 77 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -145,6 +145,7 @@
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -467,6 +468,10 @@ def __init__(

# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
self.is_hybrid_gdn = (
self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
)

if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
@@ -813,6 +818,16 @@ def init_memory_pool_and_cache(self):
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif self.is_hybrid_gdn:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache,
@@ -1686,6 +1701,25 @@ def check_memory(self):
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
)
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
(
full_num_used,
mamba_num_used,
_,
_,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
) = self._get_mamba_token_info()
memory_leak = (
full_num_used != self.tree_cache.full_protected_size()
or mamba_num_used != self.tree_cache.mamba_protected_size()
)
token_msg = (
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
)
else:
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
@@ -1736,6 +1770,17 @@
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
elif self.is_hybrid_gdn:
(
num_used,
_,
token_usage,
_,
_,
_,
_,
_,
) = self._get_mamba_token_info()
else:
num_used, token_usage, _, _ = self._get_token_info()
num_running_reqs = len(self.running_batch.reqs)
@@ -1763,7 +1808,9 @@
self._publish_kv_events()

def check_tree_cache(self):
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
):
self.tree_cache.sanity_check()

def _get_token_info(self):
@@ -1773,6 +1820,35 @@ def _get_token_info(self):
token_usage = num_used / self.max_total_num_tokens
return num_used, token_usage, available_size, evictable_size

def _get_mamba_token_info(self):
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
full_available_size = self.token_to_kv_pool_allocator.available_size()
full_evictable_size = (
self.tree_cache.full_evictable_size() if is_radix_tree else 0
)
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
mamba_evictable_size = (
self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
)
full_num_used = self.token_to_kv_pool_allocator.size - (
full_available_size + full_evictable_size
)
mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
mamba_available_size + mamba_evictable_size
)
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
return (
full_num_used,
mamba_num_used,
full_token_usage,
mamba_usage,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
)

def _get_swa_token_info(self):
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
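`_get_mamba_token_info` derives usage for both pools from one identity: used = size - (available + evictable), usage = used / size. `check_memory` then flags a leak whenever the used count of either pool differs from what the tree reports as protected by in-flight requests. A sketch of that accounting with made-up sizes; the same formula applies to the full KV pool and the mamba state pool alike:

```python
def pool_usage(size: int, available: int, evictable: int) -> tuple[int, float]:
    # Used slots are whatever is neither free nor reclaimable from the tree.
    num_used = size - (available + evictable)
    return num_used, num_used / size


num_used, usage = pool_usage(size=2048, available=1536, evictable=384)
assert (num_used, usage) == (128, 0.0625)

# check_memory's leak test, paraphrased: at an idle checkpoint every used
# slot must be protected by an in-flight request.
protected_size = 128  # invented; the real value comes from the tree cache
assert num_used == protected_size  # a mismatch would signal a memory leak
```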
36 changes: 36 additions & 0 deletions python/sglang/srt/managers/scheduler_metrics_mixin.py
@@ -104,6 +104,23 @@ def log_prefill_stats(
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
(
full_num_used,
_,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
num_used = full_num_used
token_usage = full_token_usage
token_usage_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"mamba usage: {mamba_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_usage_msg = f"token usage: {token_usage:.2f}, "
@@ -203,6 +220,25 @@ def log_decode_stats(
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
(
full_num_used,
mamba_used,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
num_used = full_num_used
token_usage = full_token_usage
token_usage_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"mamba num: {mamba_used}, "
f"mamba usage: {mamba_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
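Both logging branches surface the second pool next to the usual counters. A hedged sketch of the decode-stats fragment the branch above builds, with invented values standing in for `_get_mamba_token_info()` output:

```python
full_num_used, mamba_used = 4096, 12
full_token_usage, mamba_usage = 0.25, 0.09

token_usage_msg = (
    f"#full token: {full_num_used}, "
    f"full token usage: {full_token_usage:.2f}, "
    f"mamba num: {mamba_used}, "
    f"mamba usage: {mamba_usage:.2f}, "
)
print(token_usage_msg)
# -> #full token: 4096, full token usage: 0.25, mamba num: 12, mamba usage: 0.09,
```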