lmdeploy/messages.py (1 change: 1 addition & 0 deletions)
@@ -110,6 +110,7 @@ class GenerationConfig:
     logits_processors: Optional[List[LogitsProcessor]] = None
     output_logits: Literal['all', 'generation'] = None
     output_last_hidden_state: Literal['all', 'generation'] = None
+    include_stop_str_in_output: bool = False

     # for disaggregation
     with_cache: bool = False
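With the new field on `GenerationConfig`, callers of the Python API can opt in per request. Below is a minimal sketch using the standard `pipeline` entry point; the model name is only illustrative and not part of this change, and the exact stop string depends on the model's chat template.

from lmdeploy import GenerationConfig, pipeline

# Illustrative model; any chat model supported by lmdeploy should work.
pipe = pipeline('internlm/internlm2_5-7b-chat')

# Opt in to keeping the stop string: when generation ends on a stop token,
# its decoded text is appended to the returned response instead of being dropped.
gen_config = GenerationConfig(max_new_tokens=64, include_stop_str_in_output=True)

outputs = pipe(['Hi, please introduce yourself'], gen_config=gen_config)
print(outputs[0].text)  # expected to end with the model's stop string, e.g. '<|im_end|>'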
lmdeploy/serve/async_engine.py (14 changes: 8 additions & 6 deletions)
@@ -682,7 +682,6 @@ async def generate(
             step: int = 0,
             do_preprocess: bool = True,
             adapter_name: Optional[str] = None,
-            skip_stop_tokens: bool = True,
             rewind_stop_tokens: bool = False,
             input_ids: Optional[List] = None,
             enable_thinking: Optional[bool] = None,
@@ -778,9 +777,8 @@ async def generate(
         def is_error(status):
             return status not in [ResponseType.SUCCESS, ResponseType.FINISH, ResponseType.CANCEL]

-        # used to skip / rewind stop words in interactive mode
         stop_ids = []
-        if skip_stop_tokens and not gen_config.ignore_eos:
+        if not gen_config.ignore_eos:
             stop_ids = gen_config.stop_token_ids or []

         metrics_processor.increment_total_requests()
@@ -864,11 +862,15 @@ def is_error(status):

                 if not is_error(outputs.status):
                     finish_reason = 'stop' if outputs.token_ids[-1] in stop_ids else 'length'
-                    # utf-8 char at the end means it's a potential unfinished
-                    # byte sequence
+                    # utf-8 char at the end means it's a potential unfinished byte sequence
                     if not response.endswith('�'):
                         # avoid returning the last response twice
                         response = ''
+                        token_ids = []
+                        if gen_config.include_stop_str_in_output and finish_reason == 'stop':
+                            # return the eos token id (MUST be in a list) and its string
+                            token_ids = outputs.token_ids[-1:]
+                            response = self.tokenizer.decode(token_ids, skip_special_tokens=False)
                     logger.info(f'session {session_id} finished, reason '
                                 f'"{finish_reason}", input_tokens '
                                 f'{len(input_ids)}, output_tokens {gen_len}')
Expand All @@ -877,7 +879,7 @@ def is_error(status):
                                  len(input_ids),
                                  gen_len,
                                  finish_reason,
-                                 token_ids=[],
+                                 token_ids=token_ids,
                                  cache_block_ids=outputs.cache_block_ids)
                     # Update a session's sequence only when it is in finished status
                     if outputs.status == ResponseType.FINISH:
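The core of the change is the final added branch above: when the request asks for the stop string and the sequence ended on a stop token, the last token id is decoded on its own, with special tokens kept, and returned alongside the text. Here is a standalone sketch of that decoding step, using a Hugging Face tokenizer as a stand-in for the engine's `self.tokenizer`; the model name is only an assumption for illustration.

from transformers import AutoTokenizer

# Stand-in for AsyncEngine's tokenizer; the model name is an assumption.
tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')

# Pretend these are the generated ids and the sequence ended on the EOS token.
output_token_ids = tokenizer.encode('Hello there') + [tokenizer.eos_token_id]

# Mirror the patched branch: slice the last id so decode() receives a list,
# and keep special tokens so the stop string is not stripped from the text.
stop_token_ids = output_token_ids[-1:]
stop_str = tokenizer.decode(stop_token_ids, skip_special_tokens=False)
print(stop_token_ids, repr(stop_str))

Passing a bare int instead of a one-element list is what the "MUST be in a list" comment in the diff warns against: `decode` expects a sequence of ids.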
lmdeploy/serve/openai/api_server.py (1 change: 1 addition & 0 deletions)
@@ -406,6 +406,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         repetition_penalty=request.repetition_penalty,
         ignore_eos=request.ignore_eos,
         stop_words=request.stop,
+        include_stop_str_in_output=request.include_stop_str_in_output,
         skip_special_tokens=request.skip_special_tokens,
         response_format=response_format,
         logits_processors=logits_processors,
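On the serving side the field is simply forwarded from the request body into `GenerationConfig`, so any OpenAI-compatible client can set it via `extra_body`. A hedged sketch against a locally running `lmdeploy serve api_server` instance follows; the host, port, and API key are assumptions (23333 is the usual default).

from openai import OpenAI

# Assumes `lmdeploy serve api_server <model>` is listening on the default port.
client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='dummy')

model_name = client.models.list().data[0].id
resp = client.chat.completions.create(
    model=model_name,
    messages=[{'role': 'user', 'content': 'Say hi and then stop.'}],
    extra_body={'include_stop_str_in_output': True},
)

# With the flag set and finish_reason == 'stop', the content should end with the
# model's stop string (for example '<|im_end|>' on InternLM/Qwen chat templates).
print(resp.choices[0].finish_reason, resp.choices[0].message.content)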
lmdeploy/serve/openai/protocol.py (12 changes: 6 additions & 6 deletions)
@@ -107,16 +107,15 @@ class ChatCompletionRequest(BaseModel):
"""Chat completion request."""
model: str

messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]]) # noqa
messages: Union[str, List[Dict[str, Any]]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]])
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal['auto', 'required', 'none']] = Field(default='auto',
examples=['none']) # noqa
tool_choice: Union[ToolChoice, Literal['auto', 'required', 'none']] = Field(default='auto', examples=['none'])
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
n: Optional[int] = 1
logit_bias: Optional[Dict[str, float]] = Field(default=None, examples=[None]) # noqa
logit_bias: Optional[Dict[str, float]] = Field(default=None, examples=[None])
max_completion_tokens: Optional[int] = Field(
default=None,
examples=[None],
@@ -128,15 +127,15 @@ class ChatCompletionRequest(BaseModel):
         examples=[None],
         deprecated='max_tokens is deprecated in favor of the max_completion_tokens field',
     )
-    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])  # noqa
+    stop: Optional[Union[str, List[str]]] = Field(default=None, examples=[None])

     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = Field(default=None, examples=[None])
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
     reasoning_effort: Optional[Literal['low', 'medium', 'high']] = None
-    response_format: Optional[ResponseFormat] = Field(default=None, examples=[None])  # noqa
+    response_format: Optional[ResponseFormat] = Field(default=None, examples=[None])
     # additional argument of lmdeploy
     do_preprocess: Optional[bool] = True
     repetition_penalty: Optional[float] = 1.0
Expand All @@ -150,6 +149,7 @@ class ChatCompletionRequest(BaseModel):
     min_p: float = 0.0
     enable_thinking: Optional[bool] = None
     return_token_ids: Optional[bool] = False
+    include_stop_str_in_output: Optional[bool] = False


 class FunctionCall(BaseModel):
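Because the new request field defaults to `False`, existing clients are unaffected; only requests that set it explicitly change behavior. A small sketch of how the request model treats the field, assuming pydantic v2 (which the `deprecated=` Field argument above already implies):

from lmdeploy.serve.openai.protocol import ChatCompletionRequest

# Older clients that do not send the field keep the previous behaviour.
req = ChatCompletionRequest(model='dummy', messages=[{'role': 'user', 'content': 'hi'}])
assert req.include_stop_str_in_output is False

# Clients opting in just add the key to the JSON body.
req = ChatCompletionRequest.model_validate({
    'model': 'dummy',
    'messages': [{'role': 'user', 'content': 'hi'}],
    'include_stop_str_in_output': True,
})
assert req.include_stop_str_in_output is True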