
Commit c775937

[Feature]: Support speculative decoding (#3945)
* support spec with tp and cudagraph
* fa3 cudagraph
* only use fa3 decode for spec decoding model
* update req
* add simple doc
* resolve comments
* use spec decode config inside pt
* resolve comments
* resolve comment
* fix doc
* fix long context infer
1 parent 21c22f0 commit c775937


68 files changed, +3903 -361 lines changed

benchmark/benchmark_serving.py

Lines changed: 6 additions & 5 deletions
@@ -13,14 +13,15 @@ def get_launching_server_cmd(model_path, backend, server_config):
     elif backend == 'sglang':
         cmd = ['python3', '-m', 'sglang.launch_server', '--model-path', model_path]
     elif backend == 'vllm':
-        cmd = ['vllm', 'serve', '--model', model_path]
+        cmd = ['vllm', 'serve', model_path]
     else:
         raise ValueError(f'unknown backend: {backend}')
     for key, value in server_config.items():
         # Convert snake_case to kebab-case for command line args
         key = key.replace('_', '-')
         cmd.append(f'--{key}')
-        cmd.append(str(value))
+        if str(value):
+            cmd.append(str(value))
     # Special handling for proxy server case
     if server_config.get('proxy_url') and server_config.get('dp'):
         cmd.append('--allow-terminate-by-client')
@@ -66,9 +67,9 @@ def get_server_ip_port(backend: str, server_config: Dict) -> Tuple[str, int]:
         server_ip = server_config.get('server_ip', '0.0.0.0')
         server_port = server_config.get('server_port', 23333)
     elif backend == 'sglang':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 30000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 30000))
     elif backend == 'vllm':
-        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('server_port', 8000))
+        return (server_config.get('server_ip', '0.0.0.0'), server_config.get('port', 8000))
     else:
         raise ValueError(f'unknown backend: {backend}')
     return server_ip, server_port
@@ -131,7 +132,7 @@ def benchmark(model_path: str, backend: str, server_config: Dict, data_config: D

     try:

-        print(f"Starting api_server: {' '.join(server_cmd)}")
+        print(f"Starting api_server: {' '.join(server_cmd)}", flush=True)
         proc = subprocess.Popen(server_cmd)
         # Wait for the server to be ready
         wait_server_ready(server_ip, server_port)

docs/en/advance/spec_decoding.md

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# Speculative Decoding

Speculative decoding is an optimization technique that introduces a lightweight draft model to propose several next tokens; the main model then verifies the proposals in a single forward pass and accepts the longest matching prefix. Compared with standard auto-regressive decoding, this method lets the system generate multiple tokens per step.

> \[!NOTE\]
> This is an experimental feature in lmdeploy.
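
The propose-and-verify loop can be pictured with a minimal sketch. This is an illustration only, not lmdeploy's implementation; `argmax_next` and `argmax_each_position` are hypothetical greedy-decoding helpers standing in for real model calls:

```python
# Conceptual sketch of one speculative-decoding step (not lmdeploy code).
# `draft_model.argmax_next` and `main_model.argmax_each_position` are
# hypothetical helpers returning greedy next-token predictions.
def speculative_step(draft_model, main_model, tokens, num_speculative_tokens=3):
    # 1. The draft model proposes `num_speculative_tokens` candidate tokens.
    proposals, draft = [], list(tokens)
    for _ in range(num_speculative_tokens):
        token = draft_model.argmax_next(draft)
        proposals.append(token)
        draft.append(token)

    # 2. One forward pass of the main model scores the original tokens plus
    #    all proposals; predictions[j] is the main model's next token after
    #    the prefix ending at position j.
    predictions = main_model.argmax_each_position(tokens + proposals)

    # 3. Accept the longest prefix of proposals the main model agrees with,
    #    then append the main model's own token at the first mismatch
    #    (or a bonus token if every proposal was accepted).
    accepted = []
    for i, proposed in enumerate(proposals):
        expected = predictions[len(tokens) - 1 + i]
        if proposed != expected:
            accepted.append(expected)
            break
        accepted.append(proposed)
    else:
        accepted.append(predictions[-1])
    return tokens + accepted
```

Each step therefore yields between one and `num_speculative_tokens + 1` new tokens, which is where the speed-up over one-token-at-a-time decoding comes from.
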
## Examples

Here are some examples.

### Eagle 3

#### Prepare

Install [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)

```shell
git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### pipeline

```python
from lmdeploy import PytorchEngineConfig, pipeline
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
    spec_cfg = SpeculativeConfig(
        method='eagle3',
        num_speculative_tokens=3,
        model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',
    )
    pipe = pipeline(model_path, backend_config=PytorchEngineConfig(max_batch_size=128), speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```
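
In this snippet, `method` selects the speculative algorithm, `model` points at the draft model weights, and `num_speculative_tokens` sets how many draft tokens are proposed per step; these fields correspond to the `--speculative-algorithm`, `--speculative-draft-model` and `--speculative-num-draft-tokens` flags used for serving below.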

#### serving

```shell
lmdeploy serve api_server \
    meta-llama/Llama-3.1-8B-Instruct \
    --backend pytorch \
    --server-port 24545 \
    --speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
    --speculative-algorithm eagle3 \
    --speculative-num-draft-tokens 3 \
    --max-batch-size 128 \
    --enable-metrics
```
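
Once the server is up, it can be queried through the OpenAI-compatible endpoints that `lmdeploy serve api_server` exposes; speculative decoding is transparent to clients. A minimal client sketch, assuming the `openai` package is installed and the port configured above:

```python
# Minimal client sketch for the api_server started above; speculative decoding
# is transparent to clients, so this is an ordinary chat-completion request.
from openai import OpenAI

client = OpenAI(base_url='http://0.0.0.0:24545/v1', api_key='none')
model_name = client.models.list().data[0].id  # the served model
response = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Hi, pls intro yourself'}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```
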
### Deepseek MTP

#### Prepare

Install [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)

```shell
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```

#### pipeline

```python
from lmdeploy import PytorchEngineConfig, pipeline
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'deepseek-ai/DeepSeek-V3'
    spec_cfg = SpeculativeConfig(
        method='deepseek_mtp',
        num_speculative_tokens=3,
    )
    pipe = pipeline(model_path,
                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### serving

```shell
lmdeploy serve api_server \
    deepseek-ai/DeepSeek-V3 \
    --backend pytorch \
    --server-port 24545 \
    --tp 16 \
    --speculative-algorithm deepseek_mtp \
    --speculative-num-draft-tokens 3 \
    --max-batch-size 128 \
    --enable-metrics
```

docs/en/index.rst

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ Documentation
    advance/pytorch_profiling.md
    advance/metrics.md
    advance/context_parallel.md
+   advance/spec_decoding.md

 .. toctree::
    :maxdepth: 1

docs/zh_cn/advance/spec_decoding.md

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# Speculative Decoding

Speculative decoding is an optimization technique that introduces a lightweight draft model to predict several subsequent tokens; the main model then verifies them during its forward pass and accepts the longest matching token sequence. Compared with standard auto-regressive decoding, this approach lets the system generate multiple tokens at once.

> \[!NOTE\]
> Note that this is an experimental feature in lmdeploy.

## Examples

Please refer to the following usage examples.

### Eagle 3

#### Install dependencies

Install [FlashAttention-3](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release)

```shell
git clone --depth=1 https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

#### pipeline

```python
from lmdeploy import PytorchEngineConfig, pipeline
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'meta-llama/Llama-3.1-8B-Instruct'
    spec_cfg = SpeculativeConfig(
        method='eagle3',
        num_speculative_tokens=3,
        model='yuhuili/EAGLE3-LLaMA3.1-Instruct-8B',
    )
    pipe = pipeline(model_path, backend_config=PytorchEngineConfig(max_batch_size=128), speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### serving

```shell
lmdeploy serve api_server \
    meta-llama/Llama-3.1-8B-Instruct \
    --backend pytorch \
    --server-port 24545 \
    --speculative-draft-model yuhuili/EAGLE3-LLaMA3.1-Instruct-8B \
    --speculative-algorithm eagle3 \
    --speculative-num-draft-tokens 3 \
    --max-batch-size 128 \
    --enable-metrics
```

### Deepseek MTP

#### Install dependencies

Install [FlashMLA](https://github.com/deepseek-ai/FlashMLA?tab=readme-ov-file#installation)

```shell
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```

#### pipeline

```python
from lmdeploy import PytorchEngineConfig, pipeline
from lmdeploy.messages import SpeculativeConfig


if __name__ == '__main__':

    model_path = 'deepseek-ai/DeepSeek-V3'
    spec_cfg = SpeculativeConfig(
        method='deepseek_mtp',
        num_speculative_tokens=3,
    )
    pipe = pipeline(model_path,
                    backend_config=PytorchEngineConfig(tp=16, max_batch_size=128),
                    speculative_config=spec_cfg)
    response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
    print(response)
```

#### serving

```shell
lmdeploy serve api_server \
    deepseek-ai/DeepSeek-V3 \
    --backend pytorch \
    --server-port 24545 \
    --tp 16 \
    --speculative-algorithm deepseek_mtp \
    --speculative-num-draft-tokens 3 \
    --max-batch-size 128 \
    --enable-metrics
```

docs/zh_cn/index.rst

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ LMDeploy 工具箱提供以下核心功能:
    advance/pytorch_profiling.md
    advance/metrics.md
    advance/context_parallel.md
+   advance/spec_decoding.md

 .. toctree::
    :maxdepth: 1

lmdeploy/api.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,7 @@
 from typing import List, Literal, Optional, Union

 from .archs import autoget_backend_config, get_task
-from .messages import PytorchEngineConfig, TurbomindEngineConfig
+from .messages import PytorchEngineConfig, SpeculativeConfig, TurbomindEngineConfig
 from .model import ChatTemplateConfig


@@ -12,6 +12,7 @@ def pipeline(model_path: str,
              chat_template_config: Optional[ChatTemplateConfig] = None,
              log_level: str = 'WARNING',
              max_log_len: int = None,
+             speculative_config: SpeculativeConfig = None,
              **kwargs):
     """
     Args:
@@ -68,6 +69,12 @@
         if backend_config is not None else None
     model_path = get_model(model_path, download_dir, revision)

+    # spec model
+    if speculative_config is not None and speculative_config.model and not os.path.exists(speculative_config.model):
+        download_dir = backend_config.download_dir \
+            if backend_config is not None else None
+        speculative_config.model = get_model(speculative_config.model, download_dir)
+
     _, pipeline_class = get_task(model_path)
     if not isinstance(backend_config, PytorchEngineConfig):
         # set auto backend mode
@@ -80,6 +87,7 @@
                           backend_config=backend_config,
                           chat_template_config=chat_template_config,
                           max_log_len=max_log_len,
+                          speculative_config=speculative_config,
                           **kwargs)

lmdeploy/cli/cli.py

Lines changed: 13 additions & 3 deletions
@@ -3,7 +3,8 @@
 import os

 from ..version import __version__
-from .utils import ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args
+from .utils import (ArgumentHelper, DefaultsAndTypesHelpFormatter, FlexibleArgumentParser, convert_args,
+                    get_speculative_config)


 class CLI(object):
@@ -44,12 +45,12 @@ def add_parser_chat():
                              ', "baichuan-inc/baichuan2-7b-chat" and so on')
         # common args
         ArgumentHelper.backend(parser)
-        # # chat template args
+        # chat template args
         ArgumentHelper.chat_template(parser)
         # model args
         ArgumentHelper.revision(parser)
         ArgumentHelper.download_dir(parser)
-        #
+
         # pytorch engine args
         pt_group = parser.add_argument_group('PyTorch engine arguments')
         ArgumentHelper.adapters(pt_group)
@@ -78,6 +79,9 @@ def add_parser_chat():
         ArgumentHelper.communicator(tb_group)
         ArgumentHelper.cp(tb_group)

+        # speculative decoding
+        ArgumentHelper.add_spec_group(parser)
+
     @staticmethod
     def add_parser_checkenv():
         """Add parser for check_env command."""
@@ -169,7 +173,13 @@ def get_gpu_topo():
     @staticmethod
     def chat(args):
         from .chat import main
+
         kwargs = convert_args(args)
+        speculative_config = get_speculative_config(args)
+        to_remove = ['speculative_algorithm', 'speculative_draft_model', 'speculative_num_draft_tokens']
+        for key in to_remove:
+            kwargs.pop(key)
+        kwargs['speculative_config'] = speculative_config
         main(**kwargs)

     @staticmethod
