Commit bf0cf51

[BugFix] fix max streaming tokens invalid (#3789)
1 parent 7e751c9 commit bf0cf51

File tree: 3 files changed, +282 -9 lines
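The fix itself is small: both streaming endpoints resolve max_streaming_response_tokens from the request field when it is set, otherwise from metadata (chat) or suffix (completions), and the new max(1, ...) guard keeps a zero or negative value from being used as the streaming batch size. A minimal standalone sketch of that logic, using only the standard library (resolve_max_streaming_tokens is an illustrative helper, not a FastDeploy API):

# Sketch of the resolve-and-clamp behaviour this commit introduces.
# resolve_max_streaming_tokens is illustrative only, not part of FastDeploy.
from typing import Optional


def resolve_max_streaming_tokens(direct_value: Optional[int], extra: Optional[dict]) -> int:
    # Prefer the directly passed value, fall back to the extra dict, then clamp to >= 1.
    value = direct_value if direct_value is not None else (extra or {}).get("max_streaming_response_tokens", 1)
    return max(1, value)


# Mirrors the cases covered by the new unit tests:
assert resolve_max_streaming_tokens(None, {"max_streaming_response_tokens": 100}) == 100  # taken from metadata
assert resolve_max_streaming_tokens(None, None) == 1  # default when nothing is provided
assert resolve_max_streaming_tokens(0, None) == 1  # an invalid 0 is clamped to 1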

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 2 additions & 5 deletions
@@ -183,6 +183,8 @@ async def chat_completion_stream_generator(
             else (request.metadata or {}).get("max_streaming_response_tokens", 1)
         ) # dierctly passed & passed in metadata

+        max_streaming_response_tokens = max(1, max_streaming_response_tokens)
+
         enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
         if enable_thinking is None:
             enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
@@ -370,11 +372,6 @@ async def chat_completion_stream_generator(
                 api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}")
                 choices = []

-            if choices:
-                chunk.choices = choices
-                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                choices = []
-
             if include_usage:
                 completion_tokens = previous_num_tokens
                 usage = UsageInfo(
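For context on what the generator does with this value: it buffers per-token deltas into choices and flushes one chunk every max_streaming_response_tokens tokens, with a final flush when the token carrying finish_reason arrives, so the trailing "if choices:" block deleted above appears redundant. A simplified, self-contained model of that batching, assuming plain dicts instead of FastDeploy's response types (batch_stream is illustrative only, not the real implementation):

import json
from typing import Iterable, Iterator


def batch_stream(deltas: Iterable[dict], max_streaming_response_tokens: int) -> Iterator[str]:
    # Buffer per-token choices and emit one SSE frame per batch of max_streaming_response_tokens.
    max_streaming_response_tokens = max(1, max_streaming_response_tokens)  # the guard added by this commit
    choices = []
    for delta in deltas:
        choices.append(delta)
        if len(choices) >= max_streaming_response_tokens or delta.get("finish_reason") is not None:
            yield f"data: {json.dumps({'choices': choices})}\n\n"
            choices = []
    # No trailing flush is needed here: the final batch is emitted when finish_reason arrives.
    yield "data: [DONE]\n\n"


# With seven deltas and a batch size of 3, the frames carry 3, 3 and 1 choices, then [DONE],
# which matches the batching the new chat integration test expects.
deltas = [{"delta": {"content": c}, "finish_reason": None} for c in "abcdef"] + [
    {"delta": {"content": "g"}, "finish_reason": "stop"}
]
for frame in batch_stream(deltas, max_streaming_response_tokens=3):
    print(frame.strip())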

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 4 deletions
@@ -331,6 +331,7 @@ async def completion_stream_generator(
             if request.max_streaming_response_tokens is not None
             else (request.suffix or {}).get("max_streaming_response_tokens", 1)
         ) # dierctly passed & passed in suffix
+        max_streaming_response_tokens = max(1, max_streaming_response_tokens)
         choices = []
         chunk = CompletionStreamResponse(
             id=request_id,
@@ -461,10 +462,6 @@ async def completion_stream_generator(
                 )
                 yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
             api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
-            if choices:
-                chunk.choices = choices
-                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                choices = []

         except Exception as e:
             api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
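Both generators emit Server-Sent-Events-style frames of the form "data: <json>\n\n" and finish with "data: [DONE]\n\n"; the new tests below decode them by hand with the same [6:-2] slicing. A hedged client-side sketch of that parsing (iter_sse_json is an illustrative helper, not part of FastDeploy):

import json
from typing import Iterable, Iterator


def iter_sse_json(frames: Iterable[str]) -> Iterator[dict]:
    # Decode the 'data: <json>\n\n' frames yielded by the stream generators; stop at the [DONE] sentinel.
    for frame in frames:
        if not (frame.startswith("data: ") and frame.endswith("\n\n")):
            raise ValueError(f"unexpected frame: {frame!r}")
        payload = frame[len("data: "):-2]  # same slice as chunk_str[6:-2] in the tests
        if payload == "[DONE]":
            return
        yield json.loads(payload)


# With max_streaming_response_tokens=3, each decoded chunk carries up to three choices.
frames = ['data: {"choices": [{"text": "a"}, {"text": "b"}, {"text": "c"}]}\n\n', "data: [DONE]\n\n"]
print([len(chunk["choices"]) for chunk in iter_sse_json(frames)])  # [3]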
Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
import json
import unittest
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, Mock, patch

from fastdeploy.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion


class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        self.engine_client = Mock()
        self.engine_client.connection_initialized = False
        self.engine_client.connection_manager = AsyncMock()
        self.engine_client.connection_manager.initialize = AsyncMock()
        self.engine_client.connection_manager.get_connection = AsyncMock()
        self.engine_client.connection_manager.cleanup_request = AsyncMock()
        self.engine_client.semaphore = Mock()
        self.engine_client.semaphore.acquire = AsyncMock()
        self.engine_client.semaphore.release = Mock()
        self.engine_client.data_processor = Mock()
        self.engine_client.is_master = True

        self.chat_serving = OpenAIServingChat(
            engine_client=self.engine_client,
            models=None,
            pid=123,
            ips=None,
            max_waiting_time=30,
            chat_template="default",
            enable_mm_output=False,
            tokenizer_base_url=None,
        )

        self.completion_serving = OpenAIServingCompletion(
            engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30
        )

    def test_metadata_parameter_setting(self):
        # A value passed via request.metadata is used when the field itself is unset.
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
            metadata={"max_streaming_response_tokens": 100},
        )

        max_tokens = (
            request.max_streaming_response_tokens
            if request.max_streaming_response_tokens is not None
            else (request.metadata or {}).get("max_streaming_response_tokens", 1)
        )

        self.assertEqual(max_tokens, 100)

    def test_default_value(self):
        # Falls back to 1 when neither the field nor metadata provides a value.
        request = ChatCompletionRequest(
            model="test-model", messages=[{"role": "user", "content": "Hello"}], stream=True
        )

        max_tokens = (
            request.max_streaming_response_tokens
            if request.max_streaming_response_tokens is not None
            else (request.metadata or {}).get("max_streaming_response_tokens", 1)
        )

        self.assertEqual(max_tokens, 1)

    def test_edge_case_zero_value(self):
        # A value of 0 is clamped to 1 by the new max(1, ...) guard.
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
            max_streaming_response_tokens=0,
        )

        max_streaming_response_tokens = (
            request.max_streaming_response_tokens
            if request.max_streaming_response_tokens is not None
            else (request.metadata or {}).get("max_streaming_response_tokens", 1)
        )
        max_streaming_response_tokens = max(1, max_streaming_response_tokens)

        self.assertEqual(max_streaming_response_tokens, 1)

    @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
    @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
    async def test_integration_with_chat_stream_generator(self, mock_processor_class, mock_logger):
        # With max_streaming_response_tokens=3, the streamed chat chunks are batched three choices at a time.
        response_data = [
            {
                "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
                "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
                "finished": False,
            },
            {
                "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
                "metrics": {"arrival_time": 0.2, "first_token_time": None},
                "finished": False,
            },
            {
                "outputs": {"token_ids": [3], "text": "c", "top_logprobs": None},
                "metrics": {"arrival_time": 0.3, "first_token_time": None},
                "finished": False,
            },
            {
                "outputs": {"token_ids": [4], "text": "d", "top_logprobs": None},
                "metrics": {"arrival_time": 0.4, "first_token_time": None},
                "finished": False,
            },
            {
                "outputs": {"token_ids": [5], "text": "e", "top_logprobs": None},
                "metrics": {"arrival_time": 0.5, "first_token_time": None},
                "finished": False,
            },
            {
                "outputs": {"token_ids": [6], "text": "f", "top_logprobs": None},
                "metrics": {"arrival_time": 0.6, "first_token_time": None},
                "finished": False,
            },
            {
                "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
                "metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
                "finished": True,
            },
        ]

        mock_response_queue = AsyncMock()
        mock_response_queue.get.side_effect = response_data

        mock_dealer = Mock()
        mock_dealer.write = Mock()

        # Mock the connection manager call
        self.engine_client.connection_manager.get_connection = AsyncMock(
            return_value=(mock_dealer, mock_response_queue)
        )

        mock_processor_instance = Mock()

        async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output):
            yield response

        mock_processor_instance.process_response_chat = mock_process_response_chat_single
        mock_processor_instance.enable_multimodal_content = Mock(return_value=False)
        mock_processor_class.return_value = mock_processor_instance

        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Hello"}],
            stream=True,
            max_streaming_response_tokens=3,
        )

        generator = self.chat_serving.chat_completion_stream_generator(
            request=request,
            request_id="test-request-id",
            model_name="test-model",
            prompt_token_ids=[1, 2, 3],
            text_after_process="Hello",
        )

        chunks = []
        async for chunk in generator:
            chunks.append(chunk)

        self.assertGreater(len(chunks), 0, "No chunks received!")

        parsed_chunks = []
        for i, chunk_str in enumerate(chunks):
            if i == 0:
                continue
            if chunk_str.startswith("data: ") and chunk_str.endswith("\n\n"):
                json_part = chunk_str[6:-2]
                if json_part == "[DONE]":
                    parsed_chunks.append({"type": "done", "raw": chunk_str})
                    break
                try:
                    chunk_dict = json.loads(json_part)
                    parsed_chunks.append(chunk_dict)
                except json.JSONDecodeError as e:
                    self.fail(f"Cannot parse chunk {i+1} as JSON: {e}\n original string: {repr(chunk_str)}")
            else:
                self.fail(f"Chunk {i+1} is not in the expected 'data: JSON\\n\\n' format: {repr(chunk_str)}")
        for chunk_dict in parsed_chunks:
            choices_list = chunk_dict["choices"]
            if choices_list[-1].get("finish_reason") is not None:
                break
            else:
                self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should carry three choices")

        found_done = any("[DONE]" in chunk for chunk in chunks)
        self.assertTrue(found_done, "Did not receive '[DONE]'")

    @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
    async def test_integration_with_completion_stream_generator(self, mock_logger):
        # The completion endpoint batches the same way: three tokens end up in a single chunk of three choices.
        response_data = [
            [
                {
                    "request_id": "test-request-id-0",
                    "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
                    "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
                    "finished": False,
                },
                {
                    "request_id": "test-request-id-0",
                    "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
                    "metrics": {"arrival_time": 0.2, "first_token_time": None},
                    "finished": False,
                },
            ],
            [
                {
                    "request_id": "test-request-id-0",
                    "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
                    "metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
                    "finished": True,
                }
            ],
        ]

        mock_response_queue = AsyncMock()
        mock_response_queue.get.side_effect = response_data

        mock_dealer = Mock()
        mock_dealer.write = Mock()

        # Mock the connection manager call
        self.engine_client.connection_manager.get_connection = AsyncMock(
            return_value=(mock_dealer, mock_response_queue)
        )

        request = CompletionRequest(model="test-model", prompt="Hello", stream=True, max_streaming_response_tokens=3)

        generator = self.completion_serving.completion_stream_generator(
            request=request,
            num_choices=1,
            request_id="test-request-id",
            model_name="test-model",
            created_time=11,
            prompt_batched_token_ids=[[1, 2, 3]],
            text_after_process_list=["Hello"],
        )

        chunks = []
        async for chunk in generator:
            chunks.append(chunk)

        self.assertGreater(len(chunks), 0, "No chunks received!")

        parsed_chunks = []
        for i, chunk_str in enumerate(chunks):
            if chunk_str.startswith("data: ") and chunk_str.endswith("\n\n"):
                json_part = chunk_str[6:-2]
                if json_part == "[DONE]":
                    break
                try:
                    chunk_dict = json.loads(json_part)
                    parsed_chunks.append(chunk_dict)
                except json.JSONDecodeError as e:
                    self.fail(f"Cannot parse chunk {i+1} as JSON: {e}\n original string: {repr(chunk_str)}")
            else:
                self.fail(f"Chunk {i+1} is not in the expected 'data: JSON\\n\\n' format: {repr(chunk_str)}")
        self.assertEqual(len(parsed_chunks), 1)
        for chunk_dict in parsed_chunks:
            choices_list = chunk_dict["choices"]
            self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should carry three choices")
            self.assertEqual(
                choices_list[-1].get("finish_reason"), "stop", f"Chunk {chunk_dict} should end with finish_reason 'stop'"
            )

        found_done = any("[DONE]" in chunk for chunk in chunks)
        self.assertTrue(found_done, "Did not receive '[DONE]'")


if __name__ == "__main__":
    unittest.main()
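The test module ends with unittest.main(), so it can be run directly with the standard unittest runner (or collected by pytest) in an environment where the fastdeploy package and its OpenAI entrypoints are importable; the integration tests stub the engine client entirely, so no running engine is required.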
