
Conversation

@lgeiger (Contributor) commented Oct 29, 2025

Purpose

This is a follow-up to #23207 and adds torch.compile support to Qwen3VL. I'm keeping it as a draft PR until I've had time to run some benchmarks and correctness tests later this week.
/cc @Lucaskabela
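
For context, the general pattern is to compile the vision encoder once with dynamic shapes so the compiled graph is reused across image sizes. Below is a minimal plain-PyTorch sketch, not the exact vLLM integration; VisionBlock is a hypothetical stand-in for a Qwen3VL vision transformer block:

import torch
import torch.nn as nn

class VisionBlock(nn.Module):
    # Hypothetical stand-in for a Qwen3VL vision transformer block.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(self.proj(x))

block = VisionBlock()
# torch.compile traces on the first call and caches the compiled graph;
# dynamic=True asks Dynamo to generalize over varying sequence lengths
# instead of specializing (and later recompiling) for each new shape.
compiled_block = torch.compile(block, dynamic=True)
out = compiled_block(torch.randn(1, 128, 64))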

Test Plan

Test Result

@mergify bot added the label "qwen (Related to Qwen models)" on Oct 29, 2025
@lgeiger (Contributor, Author) commented Nov 4, 2025

I ran some benchmarks on an L40S and it looks like this change increases memory usage.

Previously I was able to run Qwen3-VL-30B-A3B-Instruct-FP8 with a maximum model length of 113872:

vllm serve Qwen/Qwen3-VL-30B-A3B-Instruct-FP8 --limit-mm-per-prompt.video 0 --gpu-memory-utilization 0.985 --max-model-len 113872

With this PR, the maximum model length seems to drop to 64624 once compilation is enabled (I also had to decrease the gpu-memory-utilization setting a bit):

vllm serve Qwen/Qwen3-VL-30B-A3B-Instruct-FP8 --limit-mm-per-prompt.video 0 --gpu-memory-utilization 0.96 --max-model-len 64624

@Lucaskabela Have you seen similar behaviour with Qwen2.5-VL?

Performance-wise, throughput also looks worse:

vllm bench serve --backend openai-chat --model Qwen/Qwen3-VL-30B-A3B-Instruct-FP8 --endpoint /v1/chat/completions --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000

main:

============ Serving Benchmark Result ============
Successful requests:                     998
Failed requests:                         2
Benchmark duration (s):                  110.06
Total input tokens:                      94304
Total generated tokens:                  119669
Request throughput (req/s):              9.07
Output token throughput (tok/s):         1087.34
Peak output token throughput (tok/s):    6141.00
Peak concurrent requests:                998.00
Total Token throughput (tok/s):          1944.20
---------------Time to First Token----------------
Mean TTFT (ms):                          46956.38
Median TTFT (ms):                        43556.70
P99 TTFT (ms):                           103163.55
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          125.36
Median TPOT (ms):                        124.58
P99 TPOT (ms):                           215.47
---------------Inter-token Latency----------------
Mean ITL (ms):                           122.37
Median ITL (ms):                         75.48
P99 ITL (ms):                            435.31
==================================================

torch compiled:

============ Serving Benchmark Result ============
Successful requests:                     998
Failed requests:                         2
Benchmark duration (s):                  122.11
Total input tokens:                      94303
Total generated tokens:                  119536
Request throughput (req/s):              8.17
Output token throughput (tok/s):         978.92
Peak output token throughput (tok/s):    2008.00
Peak concurrent requests:                998.00
Total Token throughput (tok/s):          1751.21
---------------Time to First Token----------------
Mean TTFT (ms):                          53149.79
Median TTFT (ms):                        54361.32
P99 TTFT (ms):                           110735.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          99.68
Median TPOT (ms):                        95.86
P99 TPOT (ms):                           197.06
---------------Inter-token Latency----------------
Mean ITL (ms):                           104.17
Median ITL (ms):                         63.71
P99 ITL (ms):                            434.82
==================================================

@Lucaskabela (Contributor) commented:

Hm, I didn't observe the model-length issue in my previous PR, since memory usage shouldn't increase at runtime (just at compile time, unless we are doing some tricks here). The throughput decrease also seems odd to me, since TPOT and ITL are both improving; it's the TTFT that regresses a bit here.

I wonder if there is some dimension we need to mark as dynamic here? If we are recompiling, that could explain both the higher TTFT and the memory increase.
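
For reference, a minimal sketch of marking a dimension dynamic with the stock PyTorch API (the function and tensor here are placeholders; where exactly this would hook into vLLM's compilation wrapper is an assumption I haven't checked):

import torch

def encode(x):
    # Placeholder for the vision forward pass under compilation.
    return x * 2

x = torch.randn(1024, 64)
# Mark dim 0 (e.g. the flattened patch/sequence dimension) as dynamic so
# Dynamo tracks its size symbolically instead of specializing on the first
# value it sees and recompiling for every new image resolution.
torch._dynamo.mark_dynamic(x, 0)
compiled = torch.compile(encode)
compiled(x)

Without such a hint (or dynamic=True), Dynamo specializes on the first shape it sees and recompiles whenever that dimension changes, which would show up as extra compile time and cache memory.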

@Lucaskabela (Contributor) commented:

One way we can check is by running tlparse on the logs. Can you try prefixing your command with TORCH_TRACE=/tmp/logs and then, after the run, pointing tlparse at the logs (tlparse /tmp/logs)? That should show us whether there are recompiles we don't expect.
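
A lighter-weight check that doesn't require tlparse (assuming a reasonably recent PyTorch) is to turn on Dynamo's recompile logging directly:

import torch._logging

# Logs the failing guard every time Dynamo recompiles, plus any graph breaks;
# roughly equivalent to setting TORCH_LOGS="recompiles,graph_breaks" in the
# environment before launching the server.
torch._logging.set_logs(recompiles=True, graph_breaks=True)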

I will try to look at this tomorrow, but I'm also trying to get some vLLM changes into the PyTorch 2.9.1 release, so I may not be able to get to it; I'll update after investigating on my end.

@lgeiger (Contributor, Author) commented Nov 4, 2025

> I will try to look at this tomorrow, but I'm also trying to get some vLLM changes into the PyTorch 2.9.1 release, so I may not be able to get to it; I'll update after investigating on my end.

All good. I'm just documenting it here. I'll also have a look when I have time later this or next week.

@Lucaskabela (Contributor) commented:

I also wonder whether the FP8 quantization could be contributing to this overhead? I haven't looked much into how that quantization path interacts with compile.
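
If it is the FP8 path, one rough way to check is to ask Dynamo for a graph-break report (a sketch; fn below is a placeholder for whatever module actually gets wrapped in torch.compile):

import torch

def fn(x):
    # Placeholder for the compiled forward; swap in the real module to test.
    return torch.relu(x)

# torch._dynamo.explain traces the function and reports graph counts,
# graph breaks, and their reasons without running Inductor codegen, which
# helps spot ops the quantized path makes untraceable.
report = torch._dynamo.explain(fn)(torch.randn(8, 8))
print(report)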

@Lucaskabela (Contributor) commented Nov 6, 2025

Running a warmed-up model (run the benchmark twice, take the second result), I got:

============ Serving Benchmark Result ============
Successful requests:                     986       
Failed requests:                         14        
Benchmark duration (s):                  152.16    
Total input tokens:                      39529     
Total generated tokens:                  118940    
Request throughput (req/s):              6.48      
Output token throughput (tok/s):         781.69    
Peak output token throughput (tok/s):    965.00    
Peak concurrent requests:                986.00    
Total Token throughput (tok/s):          1041.48   
---------------Time to First Token----------------
Mean TTFT (ms):                          75803.62  
Median TTFT (ms):                        76680.59  
P99 TTFT (ms):                           148796.56 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.46     
Median TPOT (ms):                        19.83     
P99 TPOT (ms):                           24.15     
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.44     
Median ITL (ms):                         14.71     
P99 ITL (ms):                            67.58     
==================================================

vs

============ Serving Benchmark Result ============
Successful requests:                     986       
Failed requests:                         14        
Benchmark duration (s):                  148.46    
Total input tokens:                      39529     
Total generated tokens:                  118327    
Request throughput (req/s):              6.64      
Output token throughput (tok/s):         797.01    
Peak output token throughput (tok/s):    995.00    
Peak concurrent requests:                986.00    
Total Token throughput (tok/s):          1063.27   
---------------Time to First Token----------------
Mean TTFT (ms):                          72104.10  
Median TTFT (ms):                        71299.37  
P99 TTFT (ms):                           145167.22 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          19.25     
Median TPOT (ms):                        19.62     
P99 TPOT (ms):                           24.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           21.53     
Median ITL (ms):                         14.69     
P99 ITL (ms):                            64.98     
==================================================

I think this supports my idea that the current integration has some recompiles happening first. I didn't observe the same memory issue, but I couldn't run the command you provided on main and had to reduce the sequence length to fit on my machine. I'll investigate the recompiles with tlparse.

@Lucaskabela (Contributor) commented:

So I tried running TORCH_TRACE=/tmp/log with-proxy python examples/offline_inference/vision_language.py -m qwen3_vl and then tlparse /tmp/log to get more insight. The command crashed with an Inductor bug, which would indeed suggest there is some issue with compilation here! I will try to dig deeper, but hopefully this gives some ideas of where to look.
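
If the crash is specifically in Inductor, one generic way to narrow it down (not vLLM-specific; vision_forward below is a placeholder for the failing module) is to bisect across compile backends:

import torch

def vision_forward(x):
    # Placeholder for the module that crashes under compilation.
    return torch.nn.functional.gelu(x)

x = torch.randn(16, 64)
# "eager" exercises only Dynamo tracing, "aot_eager" adds AOTAutograd, and the
# default "inductor" adds codegen; the first backend that fails points at the
# layer of the stack that owns the bug.
for backend in ("eager", "aot_eager", "inductor"):
    torch._dynamo.reset()  # clear caches so each backend compiles from scratch
    torch.compile(vision_forward, backend=backend)(x)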
