17 changes: 16 additions & 1 deletion fastdeploy/engine/request.py
@@ -294,6 +294,7 @@ class CompletionOutput:
decode_type: int = 0
logprob: Optional[float] = None
top_logprobs: Optional[LogprobsLists] = None
draft_top_logprobs: Optional[LogprobsLists] = None
logprobs: Optional[SampleLogprobs] = None
draft_token_ids: list[int] = None
text: Optional[str] = None
@@ -308,9 +309,9 @@ def to_dict(self):
"index": self.index,
"send_idx": self.send_idx,
"token_ids": self.token_ids,
"decode_type": self.decode_type,
"logprob": self.logprob,
"top_logprobs": self.top_logprobs,
"draft_top_logprobs": self.draft_top_logprobs,
"logprobs": self.logprobs,
"draft_token_ids": self.draft_token_ids,
"text": self.text,
@@ -336,6 +337,8 @@ def __repr__(self) -> str:
f"draft_token_ids={self.draft_token_ids}, "
f"reasoning_content={self.reasoning_content!r}, "
f"logprobs={self.logprobs}, "
f"top_logprobs={self.top_logprobs}, "
f"draft_top_logprobs={self.draft_top_logprobs}, "
)


@@ -420,6 +423,7 @@ def __init__(
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[list[int]] = None,
output_type: Optional[int] = 3,
outputs: CompletionOutput = None,
finished: bool = False,
metrics: Optional[RequestMetrics] = None,
@@ -430,6 +434,7 @@ def __init__(
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.output_type = output_type
self.outputs = outputs
self.finished = finished
self.metrics = metrics
@@ -458,12 +463,21 @@ def add(self, next_output: RequestOutput) -> None:
self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids)
self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs)
self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks)
if next_output.outputs.draft_top_logprobs is not None:
self.outputs.draft_top_logprobs.logprob_token_ids.extend(
next_output.outputs.draft_top_logprobs.logprob_token_ids
)
self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs)
self.outputs.draft_top_logprobs.sampled_token_ranks.extend(
next_output.outputs.draft_top_logprobs.sampled_token_ranks
)

def __repr__(self) -> str:
return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_type={self.output_type}, "
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"num_cached_tokens={self.num_cached_tokens}, "
@@ -484,6 +498,7 @@ def to_dict(self):
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids,
"output_type": self.output_type,
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished,
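A minimal standalone sketch of the list-concatenation merge that the new `draft_top_logprobs` branch in `RequestOutput.add` performs. The class below is a simplified stand-in, not the real fastdeploy `LogprobsLists` type.

```python
from dataclasses import dataclass, field


@dataclass
class TopLogprobs:
    # Three parallel lists, mirroring LogprobsLists' fields.
    logprob_token_ids: list = field(default_factory=list)
    logprobs: list = field(default_factory=list)
    sampled_token_ranks: list = field(default_factory=list)

    def extend(self, other: "TopLogprobs") -> None:
        # Mirrors the per-field extend() calls added for draft_top_logprobs.
        self.logprob_token_ids.extend(other.logprob_token_ids)
        self.logprobs.extend(other.logprobs)
        self.sampled_token_ranks.extend(other.sampled_token_ranks)


accumulated = TopLogprobs([[1, 2]], [[-0.1, -2.3]], [0])
next_chunk = TopLogprobs([[7, 8]], [[-0.4, -1.9]], [1])
accumulated.extend(next_chunk)
print(accumulated.logprob_token_ids)    # [[1, 2], [7, 8]]
print(accumulated.sampled_token_ranks)  # [0, 1]
```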
6 changes: 6 additions & 0 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -189,6 +189,7 @@ class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]


@@ -251,6 +252,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
arrival_time: Optional[float] = None

@@ -283,6 +285,7 @@ class CompletionResponseChoice(BaseModel):
completion_tokens: Optional[str] = None
arrival_time: Optional[float] = None
logprobs: Optional[CompletionLogprobs] = None
draft_logprobs: Optional[CompletionLogprobs] = None
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -321,6 +324,7 @@ class CompletionResponseStreamChoice(BaseModel):
text: str
arrival_time: float = None
logprobs: Optional[CompletionLogprobs] = None
draft_logprobs: Optional[CompletionLogprobs] = None
prompt_token_ids: Optional[List[int]] = None
completion_token_ids: Optional[List[int]] = None
text_after_process: Optional[str] = None
@@ -410,6 +414,7 @@ class CompletionRequest(BaseModel):
echo: Optional[bool] = False
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logprobs: Optional[int] = None
include_draft_logprobs: Optional[bool] = False
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
@@ -545,6 +550,7 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
include_draft_logprobs: Optional[bool] = False

# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
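A hypothetical request against an OpenAI-compatible FastDeploy endpoint showing the new `include_draft_logprobs` flag from `ChatCompletionRequest`. The URL, port, and model name are placeholders; `draft_logprobs` in the response is only populated when the deployment produces draft (speculative) tokens.

```python
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "logprobs": True,
    "top_logprobs": 5,
    "include_draft_logprobs": True,  # new flag added by this PR
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
choice = resp.json()["choices"][0]
print(choice["logprobs"])            # target-model logprobs (existing field)
print(choice.get("draft_logprobs"))  # draft-model logprobs (new field)
```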
23 changes: 23 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
@@ -303,12 +303,18 @@ async def chat_completion_stream_generator(

output = res["outputs"]
output_top_logprobs = output["top_logprobs"]
output_draft_top_logprobs = output["draft_top_logprobs"]
previous_num_tokens += len(output["token_ids"])
logprobs_res: Optional[LogProbs] = None
draft_logprobs_res: Optional[LogProbs] = None
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
)
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
draft_logprobs_res = self._create_chat_logprobs(
output_draft_top_logprobs, request.logprobs, request.top_logprobs
)

delta_message = DeltaMessage(
reasoning_content="",
@@ -336,6 +342,7 @@
index=0,
delta=delta_message,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
arrival_time=arrival_time,
)
if res["finished"]:
@@ -430,6 +437,7 @@ async def chat_completion_full_generator(
previous_num_tokens = 0
current_waiting_time = 0
logprob_contents = []
draft_logprob_contents = []
completion_token_ids = []
response_processor = ChatResponseProcessor(
data_processor=self.engine_client.data_processor,
@@ -476,12 +484,23 @@
# The logprob for handling the response
output = data["outputs"]
output_top_logprobs = output["top_logprobs"]
output_draft_top_logprobs = output["draft_top_logprobs"]
if output_top_logprobs is not None:
# logprobs
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
)
if logprobs_res and logprobs_res.content is not None:
logprob_contents.extend(logprobs_res.content)

# draft logprobs
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
draft_logprobs_res = self._create_chat_logprobs(
output_draft_top_logprobs, request.logprobs, request.top_logprobs
)
if draft_logprobs_res and draft_logprobs_res.content is not None:
draft_logprob_contents.extend(draft_logprobs_res.content)

if data["finished"]:
final_res = data
task_is_finished = True
@@ -515,11 +534,15 @@
logprobs_full_res = None
if logprob_contents:
logprobs_full_res = LogProbs(content=logprob_contents)
draft_logprobs_full_res = None
if draft_logprob_contents:
draft_logprobs_full_res = LogProbs(content=draft_logprob_contents)

choice = ChatCompletionResponseChoice(
index=0,
message=message,
logprobs=logprobs_full_res,
draft_logprobs=draft_logprobs_full_res,
finish_reason=None,
)
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
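A condensed sketch of the guard logic both chat generators now apply: target logprobs still key off `request.logprobs`, while draft logprobs additionally require the new `include_draft_logprobs` flag and a `draft_top_logprobs` payload that is not None. The helper name below is illustrative and not part of the codebase.

```python
def build_logprob_pair(request, output, create_fn):
    # create_fn stands in for self._create_chat_logprobs.
    logprobs_res = None
    draft_logprobs_res = None
    if request.logprobs and output["top_logprobs"] is not None:
        logprobs_res = create_fn(output["top_logprobs"], request.logprobs, request.top_logprobs)
    if request.include_draft_logprobs and output["draft_top_logprobs"] is not None:
        draft_logprobs_res = create_fn(
            output["draft_top_logprobs"], request.logprobs, request.top_logprobs
        )
    return logprobs_res, draft_logprobs_res
```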
26 changes: 26 additions & 0 deletions fastdeploy/entrypoints/openai/serving_completion.py
@@ -223,6 +223,7 @@ async def completion_full_generator(
valid_results = [dict()] * num_choices
output_tokens = [0] * num_choices
aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)]
aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)]
aggregated_token_ids = [[] for _ in range(num_choices)]
completion_batched_token_ids = [[] for _ in range(num_choices)]
current_waiting_time = 0
@@ -256,11 +257,18 @@

output = data["outputs"]
output_top_logprobs = output["top_logprobs"]
output_draft_top_logprobs = output["draft_top_logprobs"]
if output_top_logprobs is not None:
aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0])
aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1])
aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2])

# draft logprobs
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0])
aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1])
aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2])

aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])

self.engine_client.data_processor.process_response_dict(
@@ -271,6 +279,7 @@
if data.get("finished", False):
data["output_token_ids"] = output_tokens[rid]
data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid]
data["outputs"]["token_ids"] = aggregated_token_ids[rid]
valid_results[rid] = data
num_choices -= 1
@@ -423,10 +432,17 @@ async def completion_stream_generator(
await self._process_echo_logic(request, idx, res["outputs"])
output = res["outputs"]
output_top_logprobs = output["top_logprobs"]
output_draft_top_logprobs = output["draft_top_logprobs"]
logprobs_res: Optional[CompletionLogprobs] = None
draft_logprobs_res: Optional[CompletionLogprobs] = None
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)

# draft logprobs
if request.include_draft_logprobs and output_draft_top_logprobs is not None:
draft_logprobs_res = self._create_completion_logprobs(
output_draft_top_logprobs, request.logprobs, 0
)
output_tokens[idx] += 1
delta_message = CompletionResponseStreamChoice(
index=idx,
Expand All @@ -439,6 +455,7 @@ async def completion_stream_generator(
reasoning_content="",
arrival_time=arrival_time,
logprobs=logprobs_res,
draft_logprobs=draft_logprobs_res,
)
if not res["finished"] and "delta_message" in output:
delta_message_output = output["delta_message"]
@@ -523,15 +540,23 @@ def request_output_to_completion_response(
final_res = final_res_batch[idx]
prompt_token_ids = prompt_batched_token_ids[idx]
assert prompt_token_ids is not None
prompt_text = request.prompt
completion_token_ids = completion_batched_token_ids[idx]

output = final_res["outputs"]
output_top_logprobs = output["top_logprobs"]
output_draft_top_logprobs = output["draft_top_logprobs"]

aggregated_logprobs: Optional[CompletionLogprobs] = None
if output_top_logprobs is not None:
aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)

aggregated_draft_logprobs: Optional[CompletionLogprobs] = None
if output_draft_top_logprobs is not None:
aggregated_draft_logprobs = self._create_completion_logprobs(
output_draft_top_logprobs, request.logprobs, 0
)

if request.echo:
prompt_text = self._echo_back_prompt(request, idx)
token_ids = [*prompt_token_ids, *output["token_ids"]]
@@ -554,6 +579,7 @@
reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call"),
logprobs=aggregated_logprobs,
draft_logprobs=aggregated_draft_logprobs,
finish_reason=finish_reason,
)
choices.append(choice_data)
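An illustrative sketch of the per-choice aggregation buffers that `completion_full_generator` now keeps for draft logprobs: each choice holds three parallel lists in the same order as `LogprobsLists` (logprob_token_ids, logprobs, sampled_token_ranks). The values below are made up.

```python
num_choices = 2
aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)]

# One streamed chunk for choice 0, shaped like output["draft_top_logprobs"].
rid = 0
chunk = [[[101, 7]], [[-0.2, -3.1]], [0]]
aggregated_draft_top_logprobs[rid][0].extend(chunk[0])  # candidate token ids per step
aggregated_draft_top_logprobs[rid][1].extend(chunk[1])  # logprobs per step
aggregated_draft_top_logprobs[rid][2].extend(chunk[2])  # rank of the sampled token

print(aggregated_draft_top_logprobs[rid])  # [[[101, 7]], [[-0.2, -3.1]], [0]]
```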