Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Expand Down Expand Up @@ -505,6 +506,7 @@ def __init__(

# Memory pool info
self.req_pool_idx: Optional[int] = None
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape (1)

# Check finish
self.tokenizer = None
Expand Down Expand Up @@ -708,7 +710,12 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
**(
{"req": self, "cow_mamba": True}
if isinstance(tree_cache, MambaRadixCache)
else {}
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
Expand Down Expand Up @@ -804,6 +811,7 @@ def reset_for_retract(self):
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
self.mamba_pool_idx = None
self.already_computed = 0

def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
Expand Down Expand Up @@ -1000,6 +1008,13 @@ def is_empty(self):

def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
if mamba_available_size < num_reqs:
if self.tree_cache is not None and isinstance(
self.tree_cache, MambaRadixCache
):
mamba_num = max(0, num_reqs - mamba_available_size)
self.tree_cache.evict_mamba(mamba_num)
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.server_args import ServerArgs

Expand Down Expand Up @@ -357,6 +358,7 @@ def __init__(
self.is_hybrid = isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache)

self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
Expand All @@ -380,6 +382,11 @@ def rem_total_tokens(self):
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
Expand All @@ -397,6 +404,11 @@ def cur_rem_tokens(self):
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
Expand Down
74 changes: 74 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import (
Expand Down Expand Up @@ -490,6 +491,10 @@ def __init__(

# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
self.is_hybrid_gdn = (
self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
)

if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
Expand Down Expand Up @@ -792,6 +797,16 @@ def init_memory_pool_and_cache(self):
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif self.is_hybrid_gdn:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache,
Expand Down Expand Up @@ -1650,6 +1665,25 @@ def check_memory(self):
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
)
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
(
full_num_used,
mamba_num_used,
_,
_,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
) = self._get_mamba_token_info()
memory_leak = (
full_num_used != self.tree_cache.full_protected_size()
or mamba_num_used != self.tree_cache.mamba_protected_size()
)
token_msg = (
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
)
else:
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
Expand Down Expand Up @@ -1700,6 +1734,17 @@ def check_memory(self):
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
elif self.is_hybrid_gdn:
(
num_used,
_,
token_usage,
_,
_,
_,
_,
_,
) = self._get_mamba_token_info()
else:
num_used, token_usage, _, _ = self._get_token_info()
num_running_reqs = len(self.running_batch.reqs)
Expand Down Expand Up @@ -1737,6 +1782,35 @@ def _get_token_info(self):
token_usage = num_used / self.max_total_num_tokens
return num_used, token_usage, available_size, evictable_size

def _get_mamba_token_info(self):
    """Collect usage statistics for the full KV cache and the mamba state pool.

    Returns an 8-tuple:
        (full_num_used, mamba_num_used, full_token_usage, mamba_usage,
         full_available_size, full_evictable_size,
         mamba_available_size, mamba_evictable_size)
    """
    kv_allocator = self.token_to_kv_pool_allocator
    mamba_pool = self.req_to_token_pool.mamba_pool
    tree_cache = self.tree_cache

    # Evictable sizes are only tracked when the radix cache is mamba-aware;
    # any other cache type contributes nothing evictable here.
    if isinstance(tree_cache, MambaRadixCache):
        full_evictable = tree_cache.full_evictable_size()
        mamba_evictable = tree_cache.mamba_evictable_size()
    else:
        full_evictable = 0
        mamba_evictable = 0

    full_free = kv_allocator.available_size()
    mamba_free = mamba_pool.available_size()

    # Used = capacity minus (free + reclaimable-by-eviction).
    full_used = kv_allocator.size - (full_free + full_evictable)
    mamba_used = mamba_pool.size - (mamba_free + mamba_evictable)

    return (
        full_used,
        mamba_used,
        full_used / kv_allocator.size,
        mamba_used / mamba_pool.size,
        full_free,
        full_evictable,
        mamba_free,
        mamba_evictable,
    )

def _get_swa_token_info(self):
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
Expand Down
36 changes: 36 additions & 0 deletions python/sglang/srt/managers/scheduler_metrics_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ def log_prefill_stats(
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
(
full_num_used,
_,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
num_used = full_num_used
token_usage = full_token_usage
token_usage_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"mamba usage: {mamba_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_usage_msg = f"token usage: {token_usage:.2f}, "
Expand Down Expand Up @@ -203,6 +220,25 @@ def log_decode_stats(
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
(
full_num_used,
mamba_used,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
num_used = full_num_used
token_usage = full_token_usage
token_usage_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"mamba num: {mamba_used}, "
f"mamba usage: {mamba_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
Expand Down
Loading
Loading