
Commit e7cbc54

support gpt-oss function/reasoning in /v1/chat/completions (#3962)
* support gpt-oss final output
* support reasoning_effort
* output reasoning content
* fix reasoning_effort
* update
* fix ut
* support gpt-oss function/reasoning in /v1/chat/completions
* fix lint
* skip process prompt tokens
* remove commentary channel when no tools are provided
* update
* reduce warning
1 parent a25498b commit e7cbc54
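
For context, a minimal client-side sketch of what this commit enables on /v1/chat/completions. Everything below is illustrative: the server address, model name, and the get_weather tool are assumptions, not part of the commit.

from openai import OpenAI

client = OpenAI(base_url='http://localhost:23333/v1', api_key='none')  # assumed api_server address

tools = [{
    'type': 'function',
    'function': {
        'name': 'get_weather',  # hypothetical tool
        'description': 'Get the current weather for a city.',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    },
}]

response = client.chat.completions.create(
    model='gpt-oss-20b',  # assumed model name
    messages=[{'role': 'user', 'content': "What's the weather in Paris?"}],
    tools=tools,
    reasoning_effort='low',
)
message = response.choices[0].message
print(message.reasoning_content)  # parsed from the harmony 'analysis' channel
print(message.tool_calls)         # parsed from the 'commentary' channel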

File tree (4 files changed: +193 −72 lines)

lmdeploy/model.py
lmdeploy/serve/openai/api_server.py
lmdeploy/serve/openai/harmony_utils.py
lmdeploy/tokenizer.py

lmdeploy/model.py

Lines changed: 8 additions & 1 deletion
@@ -737,7 +737,7 @@ class HFChatTemplate(BaseChatTemplate):
 
     def __init__(self, model_path: str = '', **kwargs):
         try:
-            from transformers import AutoTokenizer
+            from transformers import AutoTokenizer, PretrainedConfig
             self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
             self.system_start, self.system_end = self._role_instruction('system')
             self.user_start, self.user_end = self._role_instruction('user')
@@ -747,6 +747,10 @@ def __init__(self, model_path: str = '', **kwargs):
                 self.stop_words.append(self.tokenizer.eos_token)
             if hasattr(self.tokenizer, 'eot_token') and self.tokenizer.eot_token is not None:
                 self.stop_words.append(self.tokenizer.eot_token)
+            cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
+            self.is_gpt_oss = getattr(cfg, 'architectures', [''])[0] == 'GptOssForCausalLM'
+            if self.is_gpt_oss:
+                self.stop_words.append('<|call|>')
         except Exception as e:
             raise ValueError(f'Try apply_chat_template failed: {e}')
 
@@ -787,6 +791,9 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
 
         if messages[-1]['role'] == 'assistant' and len(self.assistant_end) > 0:
             prompt = prompt[:-len(self.assistant_end)]  # prefix of response to let the model complete the response
+        if self.is_gpt_oss and not kwargs.get('tools'):
+            # for the gpt-oss model, removing this seems more conducive to instruction following
+            prompt = prompt.replace('commentary, ', '', 1)
         return prompt
 
     def _role_instruction(self, role):
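
The 'commentary, ' string stripped above comes from the channel list that the gpt-oss (harmony) chat template writes into the system header. A hedged sketch of the effect, assuming the header text of the harmony format (abbreviated here):

header = '# Valid channels: analysis, commentary, final.<|end|>'
print(header.replace('commentary, ', '', 1))
# -> '# Valid channels: analysis, final.<|end|>'

When the request carries no tools, the commentary channel has no consumer, so dropping it from the advertised list appears to help the model stick to instructions.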

lmdeploy/serve/openai/api_server.py

Lines changed: 92 additions & 66 deletions
@@ -30,6 +30,7 @@
                                                DistServeDropConnectionRequest, DistServeInitRequest,
                                                MigrationRequest)
 from lmdeploy.serve.async_engine import AsyncEngine
+from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
 from lmdeploy.serve.openai.protocol import ChatCompletionResponse  # noqa: E501
 from lmdeploy.serve.openai.protocol import (ChatCompletionRequest, ChatCompletionResponseChoice,
                                             ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
@@ -372,6 +373,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
     adapter_name = model_name  # got an adapter name
     request_id = str(request.session_id)
     created_time = int(time.time())
+    gpt_oss_parser = None
+    if VariableInterface.async_engine.arch == 'GptOssForCausalLM':
+        gpt_oss_parser = GptOssChatParser()
 
     if isinstance(request.stop, str):
         request.stop = [request.stop]
@@ -423,12 +427,21 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         gen_config.skip_special_tokens = False
         # internlm2 only uses contents inside function regardless of 'type'
         if not isinstance(request.tool_choice, str):
-            tools = [
-                item.function.model_dump() for item in request.tools
-                if item.function.name == request.tool_choice.function.name
-            ]
+            if gpt_oss_parser:
+                tools = [
+                    item.model_dump() for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [
+                    item.function.model_dump() for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
         else:
-            tools = [item.function.model_dump() for item in request.tools]
+            if gpt_oss_parser:
+                tools = [item.model_dump() for item in request.tools]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]
     # text completion for string input
     do_preprocess = False if isinstance(request.messages, str) else request.do_preprocess
     result_generator = VariableInterface.async_engine.generate(
@@ -486,46 +499,53 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                     completion_tokens=res.generate_token_len,
                     total_tokens=total_tokens,
                 )
+
             delta_token_ids = res.token_ids if res.token_ids is not None else []
-            delta_message = DeltaMessage(role='assistant', content=res.response)
+            if gpt_oss_parser:
+                delta_message = gpt_oss_parser.parse_streaming(res.token_ids)
+                if res.finish_reason == 'stop' and len(delta_message.tool_calls) > 0:
+                    res.finish_reason = 'tool_calls'
+            else:
+                delta_message = DeltaMessage(role='assistant', content=res.response)
+                if has_parser:
+                    current_text = current_text + res.response
+                    current_token_ids = current_token_ids + delta_token_ids
+                if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
+                    if res.finish_reason == 'stop' and streaming_tools is True:
+                        res.finish_reason = 'tool_calls'
+                    tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
+                        previous_text=previous_text,
+                        current_text=current_text,
+                        delta_text=delta_message.content,
+                        previous_token_ids=previous_token_ids,
+                        current_token_ids=current_token_ids,
+                        delta_token_ids=delta_token_ids,
+                        request=request)
+                    if tool_delta is not None:
+                        delta_message.tool_calls = tool_delta.tool_calls
+                        delta_message.content = tool_delta.content
+                        if isinstance(tool_delta.tool_calls, List) and len(tool_delta.tool_calls):
+                            streaming_tools = True
+                elif (request.tool_choice != 'none' and request.tools is not None
+                      and VariableInterface.tool_parser is None):
+                    logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
+
+                if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
+                    reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(
+                        previous_text=previous_text,
+                        current_text=current_text,
+                        delta_text=delta_message.content or '',
+                        previous_token_ids=previous_token_ids,
+                        current_token_ids=current_token_ids,
+                        delta_token_ids=delta_token_ids)
+                    if reasoning_delta is not None:
+                        delta_message.reasoning_content = reasoning_delta.reasoning_content
+                        delta_message.content = reasoning_delta.content
+                if has_parser:
+                    previous_text = current_text
+                    previous_token_ids = current_token_ids
             if request.return_token_ids:
                 delta_message.gen_tokens = delta_token_ids
-            if has_parser:
-                current_text = current_text + res.response
-                current_token_ids = current_token_ids + delta_token_ids
-            if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
-                if res.finish_reason == 'stop' and streaming_tools is True:
-                    res.finish_reason = 'tool_calls'
-                tool_delta = VariableInterface.tool_parser.extract_tool_calls_streaming(
-                    previous_text=previous_text,
-                    current_text=current_text,
-                    delta_text=delta_message.content,
-                    previous_token_ids=previous_token_ids,
-                    current_token_ids=current_token_ids,
-                    delta_token_ids=delta_token_ids,
-                    request=request)
-                if tool_delta is not None:
-                    delta_message.tool_calls = tool_delta.tool_calls
-                    delta_message.content = tool_delta.content
-                    if isinstance(tool_delta.tool_calls, List) and len(tool_delta.tool_calls):
-                        streaming_tools = True
-            elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:
-                logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
-
-            if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
-                reasoning_delta = VariableInterface.reasoning_parser.extract_reasoning_content_streaming(
-                    previous_text=previous_text,
-                    current_text=current_text,
-                    delta_text=delta_message.content or '',
-                    previous_token_ids=previous_token_ids,
-                    current_token_ids=current_token_ids,
-                    delta_token_ids=delta_token_ids)
-                if reasoning_delta is not None:
-                    delta_message.reasoning_content = reasoning_delta.reasoning_content
-                    delta_message.content = reasoning_delta.content
-            if has_parser:
-                previous_text = current_text
-                previous_token_ids = current_token_ids
             response_json = create_stream_response_json(index=0,
                                                         delta_message=delta_message,
                                                         finish_reason=res.finish_reason,
@@ -562,24 +582,34 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 cache_block_ids.append(res.cache_block_ids)
                 remote_token_ids.append(res.token_ids)
 
-        tool_calls = None
-        reasoning_content = None
-        if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
-            try:  # TODO add json_schema guidance to turbomind
-                tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
-                text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
-                if isinstance(tool_calls, List) and len(tool_calls):
-                    if final_res.finish_reason == 'stop':
-                        final_res.finish_reason = 'tool_calls'
-
-            except Exception as e:
-                logger.error(f'Failed to parse {text}. Exception: {e}.')
-                return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!')
-        elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:
-            logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
-
-        if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
-            reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)
+        if gpt_oss_parser:
+            message = gpt_oss_parser.parse_full(final_token_ids)
+            if final_res.finish_reason == 'stop' and len(message.tool_calls) > 0:
+                final_res.finish_reason = 'tool_calls'
+        else:
+            tool_calls = None
+            reasoning_content = None
+            if request.tool_choice != 'none' and VariableInterface.tool_parser is not None:
+                try:  # TODO add json_schema guidance to turbomind
+                    tool_call_info = VariableInterface.tool_parser.extract_tool_calls(text, request=request)
+                    text, tool_calls = tool_call_info.content, tool_call_info.tool_calls
+                    if isinstance(tool_calls, List) and len(tool_calls):
+                        if final_res.finish_reason == 'stop':
+                            final_res.finish_reason = 'tool_calls'
+
+                except Exception as e:
+                    logger.error(f'Failed to parse {text}. Exception: {e}.')
+                    return create_error_response(HTTPStatus.BAD_REQUEST, 'Failed to parse fc related info to json format!')
+            elif request.tool_choice != 'none' and request.tools is not None and VariableInterface.tool_parser is None:
+                logger.error('Please launch the api_server with --tool-call-parser if you want to use tool.')
+
+            if VariableInterface.reasoning_parser is not None and request.enable_thinking is not False:
+                reasoning_content, text = VariableInterface.reasoning_parser.extract_reasoning_content(text, request)
+
+            message = ChatMessage(role='assistant',
+                                  content=text,
+                                  tool_calls=tool_calls,
+                                  reasoning_content=reasoning_content)
 
         logprobs = None
         if gen_logprobs and len(final_logprobs):
@@ -588,15 +618,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
 
         assert final_res is not None
         choices = []
-        chat_message = ChatMessage(role='assistant',
-                                   content=text,
-                                   tool_calls=tool_calls,
-                                   reasoning_content=reasoning_content)
         if request.return_token_ids:
-            chat_message.gen_tokens = final_token_ids
+            message.gen_tokens = final_token_ids
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=chat_message,
+            message=message,
             logprobs=logprobs,
             finish_reason=final_res.finish_reason,
         )
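
On the streaming path the parser is fed raw token ids per chunk and returns a DeltaMessage whose fields map onto harmony channels. A hedged consumer-side sketch (server address and model name are assumptions):

from openai import OpenAI

client = OpenAI(base_url='http://localhost:23333/v1', api_key='none')
stream = client.chat.completions.create(
    model='gpt-oss-20b',
    messages=[{'role': 'user', 'content': 'Briefly explain the KV cache.'}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    # analysis-channel tokens arrive as delta.reasoning_content,
    # final-channel tokens as delta.content, and function calls as
    # delta.tool_calls with incrementally streamed arguments
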
lmdeploy/serve/openai/harmony_utils.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from https://github.com/vllm-project/vllm/blob/v0.10.2rc1/vllm/entrypoints/harmony_utils.py
+from typing import List
+
+import shortuuid
+
+from lmdeploy.serve.openai.protocol import (ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, FunctionCall,
+                                            ToolCall)
+
+try:
+    from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding
+except (ImportError, ModuleNotFoundError):
+    pass
+
+_harmony_encoding = None
+
+
+def get_encoding():
+    global _harmony_encoding
+    if _harmony_encoding is None:
+        _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+    return _harmony_encoding
+
+
+def get_streamable_parser_for_assistant() -> 'StreamableParser':
+    return StreamableParser(get_encoding(), role=Role.ASSISTANT)
+
+
+class GptOssChatParser:
+
+    def __init__(self):
+        self.parser = get_streamable_parser_for_assistant()
+
+    def parse_streaming(self, tokens: List[int]) -> DeltaMessage:
+        parser = self.parser
+        delta_message = DeltaMessage(role='assistant')
+        content = ''
+        reasoning_content = ''
+        tool_calls = []
+        delta_tool_call = None
+        for token in tokens:
+            prev_recipient = parser.current_recipient
+            parser.process(token)
+            cur_channel = parser.current_channel
+            cur_recipient = parser.current_recipient
+            delta_text = parser.last_content_delta or ''
+            if cur_channel == 'final':
+                content += delta_text
+            elif cur_channel == 'analysis':
+                reasoning_content += delta_text
+            elif cur_channel == 'commentary' and cur_recipient and cur_recipient.startswith('functions.'):
+                base_index = 0
+                for msg in parser.messages:
+                    if msg.channel == 'commentary' and msg.recipient and msg.recipient.startswith('functions.'):
+                        base_index += 1
+                if prev_recipient != cur_recipient:
+                    if delta_tool_call is not None:
+                        tool_calls.append(delta_tool_call)
+                    tool_name = cur_recipient.split('functions.', 1)[1]
+                    delta_tool_call = DeltaToolCall(id=f'chatcmpl-tool-{shortuuid.random()}',
+                                                    type='function',
+                                                    index=base_index,
+                                                    function=DeltaFunctionCall(name=tool_name, arguments=''))
+                elif delta_text:
+                    if delta_tool_call is None:
+                        delta_tool_call = DeltaToolCall(index=base_index,
+                                                        function=DeltaFunctionCall(arguments=''))
+                    delta_tool_call.function.arguments += delta_text
+
+        if delta_tool_call:
+            tool_calls.append(delta_tool_call)
+
+        delta_message.content = content if content else None
+        delta_message.reasoning_content = reasoning_content if reasoning_content else None
+        delta_message.tool_calls = tool_calls
+        return delta_message
+
+    def parse_full(self, tokens: List[int]) -> ChatMessage:
+        delta_message = self.parse_streaming(tokens)
+        tool_calls = []
+        for delta_tool_call in delta_message.tool_calls:
+            function = FunctionCall(**delta_tool_call.function.model_dump())
+            tool_calls.append(ToolCall(id=delta_tool_call.id, type=delta_tool_call.type, function=function))
+        chat_message = ChatMessage(role='assistant',
+                                   content=delta_message.content,
+                                   tool_calls=tool_calls,
+                                   reasoning_content=delta_message.reasoning_content)
+        return chat_message
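
A hedged usage sketch for the new parser. It assumes the optional openai_harmony package is installed and uses its encode(text, allowed_special=...) API (as shown in the harmony README) to fabricate assistant-side tokens; the tool name and message text are illustrative:

from lmdeploy.serve.openai.harmony_utils import GptOssChatParser, get_encoding

text = ('<|channel|>analysis<|message|>Need the weather; call the tool.<|end|>'
        '<|start|>assistant<|channel|>commentary to=functions.get_weather '
        '<|constrain|>json<|message|>{"city": "Paris"}<|call|>')
tokens = get_encoding().encode(text, allowed_special='all')

message = GptOssChatParser().parse_full(tokens)
print(message.reasoning_content)  # analysis channel -> reasoning_content
print(message.tool_calls)         # commentary channel -> ToolCall list

Note that StreamableParser keeps state across process() calls, which is why the server creates one GptOssChatParser per request and feeds it each chunk's token ids via parse_streaming.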

lmdeploy/tokenizer.py

Lines changed: 5 additions & 5 deletions
@@ -400,9 +400,6 @@ def detokenize_incrementally(self,
                                  spaces_between_special_tokens: bool = True):
         if not hasattr(state, 'stream'):
             state.stream = self.parser()
-            ids_offset = state.ids_offset
-            for token_id in all_input_ids[:ids_offset]:
-                state.stream.process(token_id)
 
         response = ''
         stream = state.stream
@@ -423,8 +420,11 @@ class Tokenizer:
     """
 
     def __init__(self, model_path: str):
-        from transformers import PretrainedConfig
-        model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
+        from transformers import AutoConfig, PretrainedConfig
+        try:
+            model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        except Exception as e:  # noqa
+            model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
         is_gpt_oss = getattr(model_cfg, 'model_type', '') == 'gpt_oss'
         from transformers.models.auto.tokenization_auto import get_tokenizer_config
         tokenizer_config = get_tokenizer_config(model_path, trust_remote_code=True)
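
Why the try/except: AutoConfig.from_pretrained resolves the registered, architecture-specific config class and reliably populates fields such as model_type (which the is_gpt_oss check relies on), but it can raise for repos it cannot resolve, in which case the generic PretrainedConfig load remains the fallback. A minimal sketch with a placeholder path:

from transformers import AutoConfig, PretrainedConfig

model_path = 'openai/gpt-oss-20b'  # placeholder
try:
    cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
except Exception:
    cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
print(getattr(cfg, 'model_type', ''))  # expected: 'gpt_oss'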
