
Conversation

yizhang2077 (Collaborator) commented Oct 4, 2025

Motivation

Ref #10438. Add a radix cache for Mamba; we will implement page_size > 1 and Marconi support soon.

Co-authored-by: hanming-lu [email protected]
Co-authored-by: hzh0425 [email protected]
Co-authored-by: thalahors [email protected]

Modifications

ref: doc

Accuracy Tests

python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4  --chunked-prefill-size 64 --max-running-requests 16

python3 benchmark/gsm8k/bench_sglang.py --num-question 1000
Accuracy: 0.949
Invalid: 0.000
Latency: 346.597 s
Output throughput: 485.867 token/s

Benchmarking and Profiling

# multi-turn benchmark

python3 -m sglang.bench_serving --backend sglang --dataset-name generated-shared-prefix --gsp-num-groups 50 --gsp-prompts-per-group 10 --gsp-system-prompt-len 10240 --gsp-question-len 256 --gsp-output-len 128 --max-concurrency 5  --port 30000

# without radix cache
 python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4 --disable-radix-cache

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 5         
Successful requests:                     500       
Benchmark duration (s):                  189.37    
Total input tokens:                      5521483   
Total generated tokens:                  64000     
Total generated tokens (retokenized):    63980     
Request throughput (req/s):              2.64      
Input token throughput (tok/s):          29157.63  
Output token throughput (tok/s):         337.97    
Total token throughput (tok/s):          29495.60  
Concurrency:                             5.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1891.89   
Median E2E Latency (ms):                 1884.63   
---------------Time to First Token----------------
Mean TTFT (ms):                          890.59    
Median TTFT (ms):                        1065.17   
P99 TTFT (ms):                           1180.80   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           7.90      
Median ITL (ms):                         6.31      
P95 ITL (ms):                            6.57      
P99 ITL (ms):                            8.14      
Max ITL (ms):                            751.16    
==================================================

# with radix cache
 python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 4

============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 5         
Successful requests:                     500       
Benchmark duration (s):                  141.38    
Total input tokens:                      5521483   
Total generated tokens:                  64000     
Total generated tokens (retokenized):    63985     
Request throughput (req/s):              3.54      
Input token throughput (tok/s):          39055.43  
Output token throughput (tok/s):         452.69    
Total token throughput (tok/s):          39508.12  
Concurrency:                             5.00      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1412.49   
Median E2E Latency (ms):                 1337.07   
---------------Time to First Token----------------
Mean TTFT (ms):                          507.33    
Median TTFT (ms):                        489.45    
P99 TTFT (ms):                           1336.39   
---------------Inter-Token Latency----------------
Mean ITL (ms):                           7.14      
Median ITL (ms):                         6.30      
P95 ITL (ms):                            6.85      
P99 ITL (ms):                            8.87      
Max ITL (ms):                            1484.16   
==================================================

Checklist

gemini-code-assist bot (Contributor)

Summary of Changes

Hello @yizhang2077, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a foundational feature for supporting Mamba radix cache (v0) within the SGLang system. The core objective is to enhance the efficiency of KV cache management for models incorporating Mamba architectures. This is achieved by implementing a specialized radix tree that intelligently handles both standard and Mamba-specific KV states, allowing for better resource utilization and faster inference. The changes span across memory allocation, request scheduling, and cache eviction policies, culminating in significant performance gains as evidenced by the provided benchmarks.

Highlights

  • Mamba Radix Cache Implementation: Introduced a new MambaRadixCache class to efficiently manage hybrid (full and Mamba) KV cache states, leveraging a radix tree structure for optimized prefix sharing (a toy sketch of the idea follows this list).
  • Memory Management Enhancements: Updated memory_pool.py to support Mamba-specific memory allocation, freeing, and state copying/forking using torch.Tensor for improved GPU compatibility and efficiency.
  • Scheduler and Policy Integration: Modified the scheduling and batch management logic across schedule_batch.py, schedule_policy.py, and scheduler.py to seamlessly integrate the new MambaRadixCache, including mechanisms for Mamba cache eviction and detailed memory usage tracking.
  • Performance Improvements: Benchmarking results demonstrate a notable increase in request throughput (from 2.64 req/s to 3.54 req/s) and input/output token throughput, alongside reduced end-to-end and time-to-first-token latencies when the Mamba radix cache is enabled.
  • Unit Testing: Added comprehensive unit tests in test_mamba_unittest.py to validate the functionality of HybridLinearKVPool, MambaPool, and MambaRadixCache, ensuring correctness of allocation, eviction, and prefix matching.
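
To make the first highlight concrete, here is a deliberately simplified toy sketch of the idea only; it is not the PR's MambaRadixCache, and every name in it is invented for illustration. Full-attention KV slots can be reused for any matched prefix length, while a Mamba state snapshot is only reusable when the match reaches the exact boundary where the snapshot was taken:

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

@dataclass
class Node:
    key: Tuple[int, ...] = ()                               # token ids covered by this edge
    kv_indices: List[int] = field(default_factory=list)     # full-attention KV slots, one per token
    mamba_handle: Optional[int] = None                      # id of the Mamba state snapshot at the edge end
    children: Dict[int, "Node"] = field(default_factory=dict)

class ToyMambaRadixCache:
    def __init__(self) -> None:
        self.root = Node()

    def insert(self, tokens: List[int], kv_indices: List[int], mamba_handle: int) -> None:
        # One child edge per inserted request; a real radix cache splits edges
        # on partial matches and merges shared prefixes.
        self.root.children[tokens[0]] = Node(tuple(tokens), list(kv_indices), mamba_handle)

    def match_prefix(self, tokens: List[int]) -> Tuple[List[int], Optional[int]]:
        node = self.root.children.get(tokens[0]) if tokens else None
        if node is None:
            return [], None
        matched = 0
        for stored, new in zip(node.key, tokens):
            if stored != new:
                break
            matched += 1
        # Full-attention KV can be reused token-by-token, but the Mamba state is
        # only valid if the match reaches the boundary where it was snapshotted.
        handle = node.mamba_handle if matched == len(node.key) else None
        return node.kv_indices[:matched], handle

cache = ToyMambaRadixCache()
cache.insert([1, 2, 3, 4], kv_indices=[10, 11, 12, 13], mamba_handle=7)
print(cache.match_prefix([1, 2, 3, 4, 5]))  # ([10, 11, 12, 13], 7): full prefix hit + valid Mamba state
print(cache.match_prefix([1, 2, 9]))        # ([10, 11], None): partial KV reuse, no reusable Mamba state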

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for Mamba radix cache, which is a significant feature enhancement. The implementation is comprehensive, touching upon scheduling, memory management, and the model execution flow. The new MambaRadixCache is well-structured, and unit tests have been added. I've identified a few areas for improvement, including a potential bug in an assertion, a type hint mismatch, and the use of a magic number that should be a constant. Overall, this is a solid contribution.

if self.is_hybrid_gdn:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
# for mamba cache radix, it need be divided by 3 (magic number now). (yizhang2077)
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3)
medium

The code uses a magic number 3 to divide max_mamba_cache_size. The comment acknowledges this. It's better to define this as a named constant with a clear explanation of why this division is necessary. This improves code readability and maintainability. For example: MAMBA_CACHE_REQS_RATIO = 3 could be defined at the top of the file or in a constants module.
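
For illustration, a minimal sketch of the suggested refactor; the constant name and the helper function are hypothetical, not code from the PR:

MAMBA_CACHE_PER_REQ_RATIO = 3  # hypothetical name for the "magic 3"

def mamba_capped_max_num_reqs(max_num_reqs: int, max_mamba_cache_size: int) -> int:
    # Each running request may pin several Mamba slots (its own state plus
    # radix-cache copies), so budget requests against the Mamba pool capacity.
    return min(max_num_reqs, max_mamba_cache_size // MAMBA_CACHE_PER_REQ_RATIO)

print(mamba_capped_max_num_reqs(4096, 3072))  # -> 1024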

@Swipe4057 (Contributor)

You need to fix the typo and rename the token_msg variables to token_usage_msg:



class MambaRadixCache(BasePrefixCache):
def __init__(
Collaborator

Is it compatible with MTP? The EAGLE fix should also be applied to MambaRadixCache.

yizhang2077 (Collaborator Author)

Yes, I think maybe we can do it in another PR.

@Swipe4057 (Contributor)

During testing, I discovered that the server crashed when token_usage reached more than 0.99.

yizhang2077 (Collaborator Author) commented Oct 5, 2025

During testing, I discovered that the server crashed when token_usage reached more than 0.99.

Do you have a reproduce command? I think token_usage > 0.99 is an abnormal state. (It is too large, and other models will crash in this state as well.)

@Swipe4057 (Contributor)

During testing, I discovered that the server crashed when token_usage reached more than 0.99.

Do you have reproduce command? I think token_usage > 0.99 is an abnormal state. (It is too large and other models will crash as well in this state)

reproduce command (server H100, tp-size=2):
python bench_serving.py
--host 127.0.0.1
--port 30000
--dataset-name generated-shared-prefix
--gsp-system-prompt-len 1000
--gsp-question-prompt-len 2744
--gsp-output-prompt-len 820
--gsp-num-groups 8
--gsp-prompts-per-group 128

The root cause is incorrect memory availability checking in the Mamba pool. Instead of checking the Mamba pool, only the MHA pool is used, which leads to attempted memory allocation in a full Mamba pool and subsequent server crash due to None being returned from mamba_pool.alloc().

Claude Sonnet's recommendations (a rough sketch of the dual-pool check idea follows the list):

  1. Scheduler.check_memory():
    • Check availability of both pools (MHA and Mamba) separately.
  2. PrefillAdder.budget_state():
    • For hybrid models with Mamba, check availability of both pools separately.
  3. ScheduleBatch.alloc_token_slots():
    • For Mamba pool, use req_to_token_pool.mamba_pool.alloc() instead of token_to_kv_pool_allocator.alloc().
  4. MambaRadixCache.match_prefix():
    • Add Mamba pool availability check before allocation.
  5. MambaRadixCache.evict_mamba():
    • Add verification that eviction will free sufficient memory.
  6. Scheduler._add_request_to_queue():
    • Add Mamba pool availability check before _prefetch_kvcache().
  7. HybridReqToTokenPool.alloc():
    • Check Mamba pool availability before allocation.
  8. Add logging when Mamba pool memory is insufficient for diagnostics.
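
A minimal, hedged sketch of the dual-pool admission check from items 1-2 above; the PoolView type and its fields are invented for illustration and are not SGLang's actual allocator or MambaPool interfaces:

from dataclasses import dataclass

@dataclass
class PoolView:
    free_slots: int
    total_slots: int

def can_admit(full_pool: PoolView, mamba_pool: PoolView,
              need_full_tokens: int, need_mamba_slots: int) -> bool:
    # Admit only if BOTH the full-attention KV pool and the Mamba state pool
    # can satisfy the request, instead of looking at the MHA pool alone.
    return (full_pool.free_slots >= need_full_tokens
            and mamba_pool.free_slots >= need_mamba_slots)

# Plenty of full-KV room but an exhausted Mamba pool must still reject:
print(can_admit(PoolView(50_000, 800_000), PoolView(0, 1536),
                need_full_tokens=4096, need_mamba_slots=1))  # False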

yizhang2077 (Collaborator Author) commented Oct 5, 2025

@Swipe4057 The Mamba pool controls memory capacity by setting the available size to 3x max_running_requests here. We can keep mamba_usage at around 0.66 at most during this benchmark. I have tried your benchmark with Qwen3-Next-80B-A3B-Instruct-FP8 on H100 and it did not crash (but I found another bug).
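
As an aside, a back-of-the-envelope sketch of that sizing (my reading of it, not code from the PR): if the pool holds 3x max_running_requests states and each running request pins at most two of them (its live state plus one cached fork), usage tops out near 2/3:

max_running_requests = 16
mamba_pool_size = 3 * max_running_requests        # pool sized at 3x, as described above
slots_per_running_request = 2                     # live state + one cached fork (assumed)
peak_used = slots_per_running_request * max_running_requests
print(peak_used / mamba_pool_size)                # 0.666..., i.e. the ~0.66 ceiling mentioned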

Swipe4057 (Contributor) commented Oct 5, 2025

@Swipe4057 The Mamba pool controls memory capacity by setting the available size to 3x max_running_requests here. We can keep mamba_usage at around 0.66 at most during this benchmark. I have tried your benchmark with Qwen3-Next-80B-A3B-Instruct-FP8 on H100 and it did not crash (but I found another bug).

Run the service with the command and try testing again:
environment:
- SGLANG_ENABLE_JIT_DEEPGEMM=1
command:
--model-path /data/models/Qwen3-Next-80B-A3B-Instruct
--served-model-name Qwen3-Next-80B-A3B-Instruct
--cuda-graph-max-bs 512
--cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 264 272 280 288 296 304 312 320 328 336 344 352 360 368 376 384 392 400 408 416 424 432 440 448 456 464 472 480 488 496 504 512
--sleep-on-idle
--port 8027
--host 0.0.0.0
--schedule-policy lof
--random-seed 11111
--context-length 131072
--grammar-backend xgrammar
--tool-call-parser qwen25
--enable-metrics
--quantization w8a8_fp8
--allow-auto-truncate
--mamba-ssm-dtype bfloat16
--max-running-requests 1024
--tp-size 2
--ep-size 2
--chunked-prefill-size 16384
--prefill-attention-backend flashinfer
--decode-attention-backend flashinfer
--mem-fraction-static 0.86
--max-running-requests 1024
--api-key 123

Model: https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct

yizhang2077 force-pushed the support_mamba_radix_cache branch from 72812db to 67d4e34 on October 5, 2025 at 18:15
yizhang2077 (Collaborator Author) commented Oct 5, 2025

@Swipe4057 I have tried it and it is OK. Could you share your server log? (The error and the mamba_usage values in the log are the important items.)

@Swipe4057 (Contributor)

You shouldn't limit the number of in-flight requests with the new magic number 3 for all operating modes. You need to include a condition that the radix cache is enabled.

if self.is_hybrid_gdn:
# for mamba cache radix, it need be divided by 3 (magic number now). (yizhang2077)
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3)
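
A hedged sketch of the gating being asked for; the helper and its flag parameters are illustrative, not the PR's actual code:

def cap_max_num_reqs(max_num_reqs: int, max_mamba_cache_size: int,
                     is_hybrid_gdn: bool, radix_cache_enabled: bool) -> int:
    if not is_hybrid_gdn:
        return max_num_reqs
    if radix_cache_enabled:
        # The radix cache pins extra Mamba states, hence the (still magic) // 3 headroom.
        return min(max_num_reqs, max_mamba_cache_size // 3)
    # Without the radix cache, one Mamba state per running request is enough.
    return min(max_num_reqs, max_mamba_cache_size)

print(cap_max_num_reqs(4096, 3072, is_hybrid_gdn=True, radix_cache_enabled=True))   # 1024
print(cap_max_num_reqs(4096, 3072, is_hybrid_gdn=True, radix_cache_enabled=False))  # 3072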

Swipe4057 (Contributor) commented Oct 5, 2025

The restart no longer occurs. I tested the new code; here are my results (I posted the server launch command earlier):

main, disable radix:
input_throughput: 17906

mr, enable radix:
input_throughput: 18279

Full:

without radix cache

{'backend': 'sglang', 'dataset_name': 'generated-shared-prefix', 'request_rate': inf, 'max_concurrency': None, 'sharegpt_output_len': None, 'random_input_len': 1024, 'random_output_len': 1024, 'random_range_ratio': 0.0, 'duration': 225.5629655458033, 'completed': 1024, 'total_input_tokens': 4039046, 'total_output_tokens': 839680, 'total_output_tokens_retokenized': 839041, 'request_throughput': 4.5397523371010315, 'input_throughput': 17906.512224764232, 'output_throughput': 3722.596916422846, 'mean_e2e_latency_ms': 131048.26526600846, 'median_e2e_latency_ms': 114622.40997888148, 'std_e2e_latency_ms': 63244.031925246876, 'p99_e2e_latency_ms': 225238.2608978264, 'mean_ttft_ms': 102325.65577756759, 'median_ttft_ms': 91108.4556542337, 'std_ttft_ms': 62789.35906378496, 'p99_ttft_ms': 202928.22604222223, 'mean_tpot_ms': 35.070341255727556, 'median_tpot_ms': 33.90271993074225, 'std_tpot_ms': 5.1958526217693874, 'p99_tpot_ms': 45.5167203996517, 'mean_itl_ms': 35.07019149012473, 'median_itl_ms': 28.129231184720993, 'std_itl_ms': 245.95497367320698, 'p95_itl_ms': 29.61028926074505, 'p99_itl_ms': 35.236083529889754, 'concurrency': 594.9266685143978, 'accept_length': None}

with radix cache

{'backend': 'sglang', 'dataset_name': 'generated-shared-prefix', 'request_rate': inf, 'max_concurrency': None, 'sharegpt_output_len': None, 'random_input_len': 1024, 'random_output_len': 1024, 'random_range_ratio': 0.0, 'duration': 220.95910985954106, 'completed': 1024, 'total_input_tokens': 4039046, 'total_output_tokens': 839680, 'total_output_tokens_retokenized': 839290, 'request_throughput': 4.634341623891111, 'input_throughput': 18279.608397080956, 'output_throughput': 3800.1601315907114, 'mean_e2e_latency_ms': 133114.31396250374, 'median_e2e_latency_ms': 145177.0340781659, 'std_e2e_latency_ms': 59586.806921843796, 'p99_e2e_latency_ms': 220620.25247588754, 'mean_ttft_ms': 94039.92438557907, 'median_ttft_ms': 100580.67605737597, 'std_ttft_ms': 57945.60860232475, 'p99_ttft_ms': 197400.40578043088, 'mean_tpot_ms': 47.70987738330239, 'median_tpot_ms': 42.28862401516483, 'std_tpot_ms': 28.60189182856475, 'p99_tpot_ms': 227.09880975949528, 'mean_itl_ms': 47.710699266049296, 'median_itl_ms': 30.049152672290802, 'std_itl_ms': 880.109046863313, 'p95_itl_ms': 34.64645054191351, 'p99_itl_ms': 279.4529473409058, 'concurrency': 616.8972059321408, 'accept_length': None}

bench:

(server H100, tp-size=2):
python bench_serving.py
--host 127.0.0.1
--port 30000
--dataset-name generated-shared-prefix
--gsp-system-prompt-len 1000
--gsp-question-prompt-len 2744
--gsp-output-prompt-len 820
--gsp-num-groups 8
--gsp-prompts-per-group 128

log:

[2025-10-05 18:37:30 TP0 EP0] Decode batch. #running-req: 235, #full token: 818755, full token usage: 0.99, mamba num: 470, mamba usage: 0.46, cuda graph: True, gen throughput (token/s): 7752.32, #queue-req: 789,
[2025-10-05 18:37:31 TP0 EP0] Decode batch. #running-req: 235, #full token: 828155, full token usage: 1.00, mamba num: 470, mamba usage: 0.46, cuda graph: True, gen throughput (token/s): 7751.90, #queue-req: 789,
[2025-10-05 18:37:31 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.0980 -> 0.7646
[2025-10-05 18:37:31 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.0980 -> 0.7646
[2025-10-05 18:37:32 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.7506 -> 0.7829
[2025-10-05 18:37:32 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.7506 -> 0.7829
[2025-10-05 18:37:32 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.7679 -> 0.8024
[2025-10-05 18:37:32 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.7679 -> 0.8024
[2025-10-05 18:37:32 TP0 EP0] Decode batch. #running-req: 232, #full token: 826859, full token usage: 1.00, mamba num: 464, mamba usage: 0.45, cuda graph: True, gen throughput (token/s): 7658.65, #queue-req: 792,
[2025-10-05 18:37:33 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.7884 -> 0.8207
[2025-10-05 18:37:33 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.7884 -> 0.8207
[2025-10-05 18:37:33 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8057 -> 0.8402
[2025-10-05 18:37:33 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8057 -> 0.8402
[2025-10-05 18:37:33 TP0 EP0] Decode batch. #running-req: 230, #full token: 828932, full token usage: 1.00, mamba num: 460, mamba usage: 0.45, cuda graph: True, gen throughput (token/s): 7786.25, #queue-req: 794,
[2025-10-05 18:37:34 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8262 -> 0.8585
[2025-10-05 18:37:34 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8262 -> 0.8585
[2025-10-05 18:37:34 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8435 -> 0.8781
[2025-10-05 18:37:34 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8435 -> 0.8781
[2025-10-05 18:37:35 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8630 -> 0.8976
[2025-10-05 18:37:35 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8630 -> 0.8976
[2025-10-05 18:37:35] INFO: 127.0.0.1:33552 - "GET /health HTTP/1.1" 200 OK
[2025-10-05 18:37:35 TP0 EP0] Decode batch. #running-req: 227, #full token: 827227, full token usage: 1.00, mamba num: 454, mamba usage: 0.44, cuda graph: True, gen throughput (token/s): 7641.70, #queue-req: 797,
[2025-10-05 18:37:35 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8825 -> 0.9171
[2025-10-05 18:37:35 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.8825 -> 0.9171
[2025-10-05 18:37:35 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9020 -> 0.9366
[2025-10-05 18:37:36 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9020 -> 0.9366
[2025-10-05 18:37:36 TP0 EP0] Decode batch. #running-req: 225, #full token: 828970, full token usage: 1.00, mamba num: 450, mamba usage: 0.44, cuda graph: True, gen throughput (token/s): 7525.47, #queue-req: 799,
[2025-10-05 18:37:36 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9215 -> 0.9561
[2025-10-05 18:37:36 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9215 -> 0.9561
[2025-10-05 18:37:36 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9400 -> 0.9768
[2025-10-05 18:37:36 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9400 -> 0.9768
[2025-10-05 18:37:37 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9618 -> 0.9963
[2025-10-05 18:37:37 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9618 -> 0.9963
[2025-10-05 18:37:37 TP0 EP0] Decode batch. #running-req: 222, #full token: 826827, full token usage: 1.00, mamba num: 444, mamba usage: 0.43, cuda graph: True, gen throughput (token/s): 7609.35, #queue-req: 802,
[2025-10-05 18:37:37 TP1 EP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9803 -> 1.0000
[2025-10-05 18:37:37 TP0 EP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #aborted_retracted_reqs: 0, #new_token_ratio: 0.9803 -> 1.0000
[2025-10-05 18:37:38 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15775, #cached-token: 0, full token usage: 0.00, mamba usage: 0.00, #running-req: 220, #queue-req: 799,

@yizhang2077 (Collaborator Author)

You shouldn't limit the number of in-flight requests with the new magic number 3 for all operating modes. You need to include a condition that the radix cache is enabled.

if self.is_hybrid_gdn:
# for mamba cache radix, it need be divided by 3 (magic number now). (yizhang2077)
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3)

If the radix cache is disabled, I think it is also fine for max_mamba_cache_size to be larger than max_running_requests.

Swipe4057 (Contributor) commented Oct 6, 2025

@yizhang2077 Do you have any idea why my test performs so much worse than yours?

Although in my test there are 8 groups of requests with the same 1000-token system prompts, in the log I see cached-token: 0.

log:
[2025-10-06 06:32:38 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15821, #cached-token: 0, full token usage: 0.23, mamba usage: 0.24, #running-req: 125, #queue-req: 895,
[2025-10-06 06:32:39 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15801, #cached-token: 0, full token usage: 0.23, mamba usage: 0.25, #running-req: 129, #queue-req: 891,
[2025-10-06 06:32:39 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15785, #cached-token: 0, full token usage: 0.24, mamba usage: 0.26, #running-req: 133, #queue-req: 887,
[2025-10-06 06:32:40 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15748, #cached-token: 0, full token usage: 0.25, mamba usage: 0.26, #running-req: 137, #queue-req: 883,
[2025-10-06 06:32:40 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15791, #cached-token: 0, full token usage: 0.25, mamba usage: 0.27, #running-req: 141, #queue-req: 879,
[2025-10-06 06:32:40 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15719, #cached-token: 0, full token usage: 0.26, mamba usage: 0.28, #running-req: 145, #queue-req: 875,
[2025-10-06 06:32:41 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15655, #cached-token: 0, full token usage: 0.27, mamba usage: 0.29, #running-req: 149, #queue-req: 871,
[2025-10-06 06:32:41 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15803, #cached-token: 0, full token usage: 0.27, mamba usage: 0.29, #running-req: 153, #queue-req: 867,
[2025-10-06 06:32:42 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15746, #cached-token: 0, full token usage: 0.28, mamba usage: 0.30, #running-req: 157, #queue-req: 863,
[2025-10-06 06:32:42 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15783, #cached-token: 0, full token usage: 0.29, mamba usage: 0.31, #running-req: 161, #queue-req: 859,
[2025-10-06 06:32:42 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15755, #cached-token: 0, full token usage: 0.29, mamba usage: 0.32, #running-req: 165, #queue-req: 855,
[2025-10-06 06:32:43 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15784, #cached-token: 0, full token usage: 0.30, mamba usage: 0.33, #running-req: 169, #queue-req: 851,
[2025-10-06 06:32:43 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15750, #cached-token: 0, full token usage: 0.31, mamba usage: 0.33, #running-req: 173, #queue-req: 847,
[2025-10-06 06:32:43 TP0 EP0] Prefill batch. #new-seq: 4, #new-token: 15765, #cached-token: 0, full token usage: 0.32, mamba usage: 0.34, #running-req: 177, #queue-req: 843,

@Swipe4057 (Contributor)

You shouldn't limit the number of in-flight requests with the new magic number 3 for all operating modes. You need to include a condition that the radix cache is enabled.

if self.is_hybrid_gdn:
# for mamba cache radix, it need be divided by 3 (magic number now). (yizhang2077)
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size // 3)

If the radix cache is disabled, I think it is also fine for max_mamba_cache_size to be larger than max_running_requests.

There will be an OOM.

hzh0425 (Collaborator) commented Oct 6, 2025

@yizhang2077 Do you have any idea why my test performs so much worse than yours?

Although in my test there are 8 groups of requests with the same 1000-token system prompts, in the log I see cached-token: 0.

Hi @Swipe4057

It might be due to too many requests causing the cache on the device to be evicted.

You can try adding --max-concurrency 5

@Swipe4057 (Contributor)

@yizhang2077 Do you have any idea why my test performs so much worse than yours?
Although in my test there are 8 groups of requests with the same 1000-token system prompts, in the log I see cached-token: 0.

Hi @Swipe4057

It might be due to too many requests causing the cache on the device to be evicted.

You can try adding --max-concurrency 5

Understood. Unfortunately, using --max-concurrency 5 is not suitable for our production environment.

@hanming-lu (Collaborator) left a comment

Overall looks good! Most comments are minor; the one regarding the assert in cache_unfinished_req is more critical. Thanks!

@hanming-lu (Collaborator) left a comment

Logic looks good to me!

I am not sure about the status of the crashing behavior from @Swipe4057; if it is also fixed, we are good to merge.

@Swipe4057 (Contributor)

Logic looks good to me!

I am not sure about the status of the crashing behavior from @Swipe4057; if it is also fixed, we are good to merge.

I'll test the current code tomorrow.

@hanming-lu (Collaborator)

Seems like the mamba tree cache sanity check is not running; let's add it?

def check_tree_cache(self):
        if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()
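
A hedged sketch of what that addition could look like, assuming MambaRadixCache exposes a sanity_check() like SWARadixCache does (an assumption to verify against the actual class):

def check_tree_cache(self):
    if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
        self.tree_cache.sanity_check()
    # Assumed addition for the hybrid GDN / Mamba path:
    elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
        self.tree_cache.sanity_check()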

hanming-lu (Collaborator) commented Oct 9, 2025

Scheduling-related: no performance overhead when tested with python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 50 --random-input 24576 --random-output 1 --random-range-ratio 1.
See https://sgl-fru7574.slack.com/archives/C09DVT13FT8/p1759971353991489

I found the mamba radix cache itself has ~12% prefill overhead. Not blocking PR merge, but we shouldn't settle on the 10% perf drop :) I tested with

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 50 --random-input 20480 --random-output 1 --random-range-ratio 1

w/ radix cache
python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 8
Input token throughput (tok/s):          64229.48

w/o radix cache
python3 -m sglang.launch_server --model Qwen/Qwen3-Next-80B-A3B-Instruct --tp 8 --disable-radix-cache
Input token throughput (tok/s):          72806.69

I am trying to figure out where it comes from. I tried:

  • commenting out the reset to 0 in MambaPool free
  • commenting out copy_from in fork_from

Neither makes a difference.

hzh0425 (Collaborator) commented Oct 9, 2025

I found the mamba radix cache itself has ~12% prefill overhead. Not blocking PR merge, but we shouldn't settle on the 10% perf drop :) I tested with

Currently, cache_unfinished_req will fork the mamba cache once, which might be introducing overhead. @hanming-lu
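
One way to sanity-check that hypothesis is a rough, standalone timing sketch of a per-request state copy; the tensor shape below is an invented placeholder, not Qwen3-Next's real Mamba state size:

import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder Mamba-like state: ~50 MB in bf16; the real per-request size differs.
state = torch.randn(48, 32, 128, 128, dtype=torch.bfloat16, device=device)

def fork(src: torch.Tensor) -> torch.Tensor:
    return src.clone()  # the per-request copy a naive fork would pay

for _ in range(3):  # warm-up
    fork(state)
if device == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
n = 100
for _ in range(n):
    fork(state)
if device == "cuda":
    torch.cuda.synchronize()
print(f"{(time.perf_counter() - start) / n * 1e3:.3f} ms per fork on {device}")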

Swipe4057 (Contributor) commented Oct 9, 2025

I found the mamba radix cache itself has ~12% prefill overhead. Not blocking PR merge, but we shouldn't settle on the 10% perf drop :) I tested with

Currently, cache_unfinished_req will fork the mamba cache once, which might be introducing overhead. @hanming-lu

Please take a look at the test I conducted earlier: #11214 (comment)
If you specify the generated-shared-prefix dataset, the script will generate synthetic requests with specified system prompt lengths and user (variable) parts. I had 8 groups with identical system prompts.

My results (duplicating here) were as follows, on 2x H100 GPUs:
reproduce command (server H100, tp-size=2):
python bench_serving.py
--host 127.0.0.1
--port 30000
--dataset-name generated-shared-prefix
--gsp-system-prompt-len 1000
--gsp-question-prompt-len 2744
--gsp-output-prompt-len 820
--gsp-num-groups 8
--gsp-prompts-per-group 128

main, disable radix:
input_throughput: 17906

mr, enable radix:
input_throughput: 18279

I shared server logs, and all prefills looked like this:
Prefill batch. #new-seq: 4, #new-token: 15791, #cached-token: 0, full token usage: 0.25, mamba usage: 0.27, #running-req: 141, #queue-req: 879

That is, cached-token: 0! This should be impossible, since cache matches should definitely exist. However, if you reduce the number of concurrent requests (for example to 5), or send requests one after another, cache matches do appear!

So I think there's a bug here.

I can't test again at the moment; our server is undergoing maintenance. I'll verify in the near future.

It would also be interesting to test something like this (--gsp-question-prompt-len 0):
python bench_serving.py
--dataset-name generated-shared-prefix
--gsp-system-prompt-len 1000
--gsp-question-prompt-len 0
--gsp-output-prompt-len 820
--gsp-num-groups 8
--gsp-prompts-per-group 128
