
Commit d0b6355

Initial gpt-oss support for turbomind (#3839)
* initial gpt-oss support
* fix lint
* fix missing include
* fix cu12.8 build
* guard cuda data types
* guard data type
1 parent a199415 commit d0b6355

81 files changed (+1076, −4076 lines)


CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,8 @@ if (BUILD_TEST)
   Catch2
   GIT_REPOSITORY https://github.com/catchorg/Catch2.git
   GIT_TAG v3.8.0
+  GIT_SHALLOW ON
+  EXCLUDE_FROM_ALL
 )
 FetchContent_MakeAvailable(Catch2)
 endif()

lmdeploy/turbomind/deploy/config.py

Lines changed: 6 additions & 1 deletion
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
 import json
-from dataclasses import asdict, fields
+from dataclasses import asdict, field, fields
 from typing import List
 
 # use pydantic.dataclasses.dataclass to check data type
@@ -61,6 +61,9 @@ class ModelConfig:
     inter_size: List[int] = None
     norm_eps: float = None
     attn_bias: int = 0
+    mlp_bias: bool = False
+    window_size: List[int] = field(default_factory=list)
+    attn_sink: bool = False
     qk_norm: bool = False
     size_per_head: int = 128
     group_size: int = 64
@@ -70,8 +73,10 @@ class ModelConfig:
     mlp_tp_size: int = 1
     model_format: str = 'hf'
     expert_num: List[int] = ()
+    expert_router_bias: bool = False
     expert_inter_size: int = 0
     experts_per_token: int = 0
+    activation_type: str = ''
     moe_shared_gate: bool = False
     norm_topk_prob: bool = False
     routed_scale: float = 1.0
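For orientation, a minimal sketch (not part of the commit) of how the new ModelConfig fields describe a gpt-oss style checkpoint; all values below are illustrative:

from lmdeploy.turbomind.deploy.config import ModelConfig

# Illustrative values only: window_size gives a per-layer sliding window
# (0 = full attention), attn_sink marks per-head sink logits, and
# expert_router_bias says the MoE router carries a bias term.
cfg = ModelConfig(mlp_bias=True,
                  window_size=[128, 0, 128, 0],
                  attn_sink=True,
                  expert_router_bias=True,
                  activation_type='gpt-oss')
assert cfg.window_size[1] == 0  # layer 1 attends over the full context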

lmdeploy/turbomind/deploy/loader.py

Lines changed: 16 additions & 6 deletions
@@ -23,10 +23,11 @@
 
 class BaseLoader(ABC):
 
-    def __init__(self, model_path: str, pattern):
+    def __init__(self, model_path: str, pattern, mappings: list):
         self.model_path = model_path
         self.pattern = pattern
         self.item_count = defaultdict(int)
+        self.mappings = mappings
 
     def get_index(self, index_name: str, file_pattern: str) -> Tuple[dict, list]:
         """Get shards and weight map (if possible) for the model."""
@@ -44,15 +45,24 @@ def get_index(self, index_name: str, file_pattern: str) -> Tuple[dict, list]:
             raise RuntimeError(f'failed to locate weight files for {self.model_path}')
         return sorted(shards), index
 
+    def map_key(self, key: str):
+        if self.mappings:
+            k = str(key)
+            for f in self.mappings:
+                k = f(k)
+            return k
+        else:
+            return key
+
     @abstractmethod
     def items(self) -> Iterator[Tuple[int, dict]]:
         pass
 
 
 class SafetensorsLoader(BaseLoader):
 
-    def __init__(self, model_path: str, pattern: str, index_name=None, file_pattern=None):
-        super().__init__(model_path, pattern)
+    def __init__(self, model_path: str, pattern: str, mappings: list, index_name=None, file_pattern=None):
+        super().__init__(model_path, pattern, mappings)
         self.shards, index = self.get_index(index_name, file_pattern)
         if not index:
             # there is no model.safetensors.index.json in the model_path,
@@ -87,7 +97,7 @@ def items(self):
             else:
                 idx = int(match[0])
                 param = params[idx]
-                param[k] = f.get_tensor(k)
+                param[self.map_key(k)] = f.get_tensor(k)
             if len(param) == self.item_count[idx]:
                 yield (idx, params.pop(idx))
         if misc:
@@ -164,8 +174,8 @@ def items(self):
         self.que.task_done()
 
 
-def create_loader(model_path: Union[str, Queue], pattern: str) -> BaseLoader:
-    args = (model_path, pattern)
+def create_loader(model_path: Union[str, Queue], pattern: str, mappings: list) -> BaseLoader:
+    args = (model_path, pattern, mappings)
 
     if isinstance(model_path, Queue):
         # used for `update_params`
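To see what the new mappings hook buys, here is a self-contained sketch of map_key at work, reusing the map_experts regexes this commit adds for gpt-oss (see gpt_oss.py below); the sample keys are illustrative:

import re

def map_experts(s):
    # normalize expert parameter names to canonical '.weight'/'.bias' suffixes
    s = re.sub(r'(experts.*proj)$', r'\1.weight', s)
    s = re.sub(r'(experts.*proj)_bias$', r'\1.bias', s)
    return s

mappings = [map_experts]

def map_key(key):
    # BaseLoader.map_key applies every mapping in order
    k = str(key)
    for f in mappings:
        k = f(k)
    return k

print(map_key('model.layers.0.mlp.experts.down_proj'))       # ...down_proj.weight
print(map_key('model.layers.0.mlp.experts.down_proj_bias'))  # ...down_proj.bias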

lmdeploy/turbomind/deploy/module.py

Lines changed: 16 additions & 5 deletions
@@ -50,6 +50,8 @@ def pad_out_dims(x: torch.Tensor, dims: int):
 
 
 def pad_in_dims(x: torch.Tensor, dims: int):
+    if x.dim() == 1:  # 1-dim object does not have input dim (e.g. bias)
+        return x
     pad = dims - x.size(0)
     assert x.dim() == 2
     assert pad >= 0
@@ -119,6 +121,8 @@ def _export(self, inter_size: int, fmt: str, idx: int, w123, kind: str, pack_fn,
         self.model.save_split(w2, fmt.format(idx, 'w2', kind), split_dim=0, split_num=self.tp, copy=is_lora_b)
 
     def apply(self, i: int, r: BaseReader):
+        if not self.inter_size[i]:
+            return
         for e in get_params(r.ffn(i, None)):
             e(partial(self._export, self.inter_size[i], self._ffn), partial(r.ffn, i), i)
 
@@ -132,7 +136,7 @@ class MoeFfn(Ffn):
     """
 
     _moe_ffn_expert = 'layers.{0}.moe_ffn.experts.E.{1}.{2}'
-    _moe_ffn_gate = 'layers.{0}.moe_ffn.gate.weight'
+    _moe_ffn_gate = 'layers.{0}.moe_ffn.gate.{1}'
     _moe_ffn_shared_gate = 'layers.{0}.moe_ffn.shared_gate.weight'
 
     def __init__(self, model: BaseOutputModel):
@@ -144,17 +148,20 @@ def __init__(self, model: BaseOutputModel):
     def apply(self, i: int, r: BaseReader):
         if self.expert_num[i] == 0:
             return
-        for p in get_params(r.moe_ffn_expert()):
+        for p in get_params(r.moe_ffn_expert(), 1):
             for e in range(self.expert_num[i]):
                 fmt = self._moe_ffn_expert.replace('E', str(e))
                 p(partial(self._export, self.inter_size, fmt), partial(r.moe_ffn_expert, e, i), i)
 
-        gate = transpose(r.moe_ffn_gate(i))
-        self.model.save_split(gate, self._moe_ffn_gate.format(i))
+        # router
+        gate = transpose(r.moe_ffn_gate(i, 'weight'))
+        self.model.save_split(gate, self._moe_ffn_gate.format(i, 'weight'))
+        bias = r.moe_ffn_gate(i, 'bias')
+        if bias is not None:
+            self.model.save_split(bias, self._moe_ffn_gate.format(i, 'bias'))
 
         if self.shared_gate:
             shared_gate = transpose(r.moe_ffn_shared_gate(i))
-            # print(shared_gate)
             self.model.save_split(shared_gate, self._moe_ffn_shared_gate.format(i))
 
 
@@ -172,6 +179,7 @@ def __init__(self, model: BaseOutputModel):
         self.head_dim = model.model_config.size_per_head
         self.attn_bias = model.model_config.attn_bias
         self.qk_norm = model.model_config.qk_norm
+        self.attn_sink = model.model_config.attn_sink
         self.group_size = max(1, model.model_config.group_size)
 
     def _reorder_and_merge(self, qkvo, gs: int):
@@ -250,6 +258,9 @@ def apply(self, i: int, r: BaseReader):
             k = permute_v2(k, self.head_dim)
             self.model.save_split(q, self._attn.format(i, 'q_norm', '')[:-1])
             self.model.save_split(k, self._attn.format(i, 'k_norm', '')[:-1])
+        if self.attn_sink:
+            sinks = r.attn_sinks(i)
+            self.model.save_split(sinks, self._attn.format(i, 'sinks', '')[:-1], split_dim=0, split_num=self.tp)
 
 
 class MLA(Module):
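The pad_in_dims guard exists because biases are 1-D and have no input dimension to pad. A runnable sketch; the hunk above only shows the head of the function, so the F.pad call here is an assumption about the hidden tail:

import torch
import torch.nn.functional as F

def pad_in_dims(x: torch.Tensor, dims: int):
    if x.dim() == 1:  # 1-dim object does not have input dim (e.g. bias)
        return x
    pad = dims - x.size(0)
    assert x.dim() == 2
    assert pad >= 0
    return F.pad(x, (0, 0, 0, pad))  # assumed: zero-pad the input dim

w = torch.randn(6, 4)
b = torch.randn(4)
assert pad_in_dims(w, 8).shape == (8, 4)  # input dim padded 6 -> 8
assert pad_in_dims(b, 8).shape == (4,)    # bias passes through untouched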

lmdeploy/turbomind/deploy/source_model/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from .deepseek2 import DeepSeek2Model  # noqa: F401
 from .deepseek_vl import DeepSeekVLModel  # noqa: F401
 from .glm4 import Glm4Model  # noqa: F401
+from .gpt_oss import GptOssModel  # noqa: F401
 from .internlm2 import InternLM2Model  # noqa: F401
 from .internvl import InternVLModel  # noqa: F401
 from .llama import LlamaModel  # noqa: F401

lmdeploy/turbomind/deploy/source_model/deepseek2.py

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@
 
 class DeepSeek2Reader(LlamaReader):
 
-    def moe_ffn_gate(self, i):
-        return self.params.get(f'model.layers.{i}.mlp.gate.weight')
+    def moe_ffn_gate(self, i, kind):
+        return self.params.get(f'model.layers.{i}.mlp.gate.{kind}')
 
     def moe_ffn_expert(self, e=None, i=None, kind=None):
         if not kind:
lmdeploy/turbomind/deploy/source_model/gpt_oss.py

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import re
+
+from .base import INPUT_MODELS
+from .llama import LlamaModel, LlamaReader
+
+
+def map_experts(str):
+    s = re.sub(r'(experts.*proj)$', r'\1.weight', str)
+    s = re.sub(r'(experts.*proj)_bias$', r'\1.bias', s)
+    return s
+
+
+class GptOssReader(LlamaReader):
+
+    mappings = [map_experts]
+
+    def moe_ffn_expert(self, e=None, i=None, kind=None):
+        if not kind:
+            return self.filter(r'experts')
+        result = []
+        for key in ['gate_up', 'down']:
+            name = f'{self.attn_layer_prefix}.{i}.mlp.experts.{key}_proj.{kind}'
+            tensor = self.params.get(name)[e]
+            if tensor.ndim == 2:
+                tensor = tensor.cuda().t()  # experts in unsloth/gpt-oss-20b-BF16 are transposed
+            if key == 'gate_up':
+                gate, up = tensor[::2], tensor[1::2]
+                result.append(self.transform(gate, kind))
+                result.append(self.transform(up, kind))
+            else:
+                result.append(self.transform(tensor, kind))
+        return (result[0], result[2], result[1])
+
+    def moe_ffn_gate(self, i, kind):
+        return self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.router.{kind}'), kind)
+
+    def attn_sinks(self, i):
+        return self.params.get(f'{self.attn_layer_prefix}.{i}.self_attn.sinks')
+
+
+@INPUT_MODELS.register_module(name='gpt-oss')
+class GptOssModel(LlamaModel):
+
+    Reader = GptOssReader
+
+    def model_info(self):
+        cfg = self.model_config
+        types = cfg['layer_types']
+        sliding_window = cfg['sliding_window']
+        info = super().model_info()
+        info.update(attn_bias=int(cfg['attention_bias']),
+                    mlp_bias=True,
+                    expert_router_bias=True,
+                    expert_num=cfg['num_local_experts'],
+                    expert_inter_size=cfg['intermediate_size'],
+                    experts_per_token=cfg['experts_per_token'],
+                    norm_topk_prob=True,
+                    inter_size=0,
+                    window_size=[sliding_window if x == 'sliding_attention' else 0 for x in types],
+                    attn_sink=True,
+                    activation_type='gpt-oss')
+        return info
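The gate_up handling is the subtle part of GptOssReader: the fused projection stores gate and up rows interleaved along the output dimension, so even rows form the gate and odd rows the up projection. A sketch with illustrative shapes:

import torch

inter_size, hidden = 4, 8
# one expert's fused projection, already oriented [out, in] after the
# transpose in moe_ffn_expert
gate_up = torch.randn(2 * inter_size, hidden)

gate, up = gate_up[::2], gate_up[1::2]  # even rows -> gate, odd rows -> up
assert gate.shape == up.shape == (inter_size, hidden)

The final (result[0], result[2], result[1]) then hands back (gate, down, up), matching the gate/down/up ordering the other readers produce.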

lmdeploy/turbomind/deploy/source_model/llama.py

Lines changed: 2 additions & 1 deletion
@@ -108,7 +108,8 @@ def __init__(self, model_path: str, tokenizer_path: str, **kwargs: dict):
         self.model_config = self.model_config.to_dict()
 
     def readers(self):
-        loader = create_loader(self.model_path, self.Reader.attn_layer_patten)
+        mappings = getattr(self.Reader, 'mappings', [])
+        loader = create_loader(self.model_path, self.Reader.attn_layer_patten, mappings)
         for i, param in loader.items():
             reader = self.Reader(param, {}, False, self.model_config, policy=self.policy)
             yield i, reader
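Readers opt in to key remapping by declaring a class-level mappings list; readers() falls back to an empty list for everything else. A minimal sketch with stub classes (not from the commit):

class PlainReader:
    pass  # no remapping declared

class RemappingReader:
    mappings = [lambda k: k.replace('_bias', '.bias')]  # illustrative

for Reader in (PlainReader, RemappingReader):
    mappings = getattr(Reader, 'mappings', [])
    print(Reader.__name__, '->', len(mappings), 'mapping(s)')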

lmdeploy/turbomind/deploy/source_model/mixtral.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@ def moe_ffn_expert(self, e=None, i=None, kind=None):
             result.append(tensor)
         return (*result, )
 
-    def moe_ffn_gate(self, i):
-        return self.params.get(f'model.layers.{i}.block_sparse_moe.gate.weight')
+    def moe_ffn_gate(self, i, kind):
+        return self.params.get(f'model.layers.{i}.block_sparse_moe.gate.{kind}')
 
 
 @INPUT_MODELS.register_module(name='mixtral')

lmdeploy/turbomind/deploy/source_model/qwen.py

Lines changed: 2 additions & 2 deletions
@@ -130,8 +130,8 @@ def moe_ffn_expert(self, e=None, i=None, kind=None):
             result.append(tensor)
         return (*result, )
 
-    def moe_ffn_gate(self, i):
-        return self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.gate.weight'), 'weight')
+    def moe_ffn_gate(self, i, kind):
+        return self.transform(self.params.get(f'{self.attn_layer_prefix}.{i}.mlp.gate.{kind}'), kind)
 
     def _ffn(self, i: int, kind: str):
         """Get ffn kind for layer i."""
