@@ -14,8 +14,11 @@
 
 import argparse
 import unittest
+from unittest import IsolatedAsyncioTestCase
 from unittest.mock import AsyncMock, MagicMock, patch
 
+import pytest
+
 from fastdeploy.benchmarks.serve import (
     BenchmarkMetrics,
     add_cli_args,
@@ -29,7 +32,7 @@
 )
 
 
-class TestServe(unittest.TestCase):
+class TestServe(IsolatedAsyncioTestCase):
     def test_add_cli_args(self):
         parser = argparse.ArgumentParser()
         add_cli_args(parser)
@@ -130,16 +133,25 @@ def test_calculate_metrics(self):
         self.assertEqual(metrics.total_input, 10)
         self.assertEqual(metrics.total_output, 20)
 
-    @patch("fastdeploy.benchmarks.serve.ASYNC_REQUEST_FUNCS", {"test_backend": AsyncMock()})
-    @patch("fastdeploy.benchmarks.serve.get_request", new_callable=AsyncMock)
+    @pytest.mark.asyncio
+    @patch("fastdeploy.benchmarks.serve.get_request")
     @patch("asyncio.gather", new_callable=AsyncMock)
-    async def test_benchmark(self, mock_gather, mock_get_request, mock_request_func):
+    async def test_benchmark(self, mock_gather, mock_get_request):
+        # Set ASYNC_REQUEST_FUNCS directly in the test
+        from fastdeploy.benchmarks.serve import ASYNC_REQUEST_FUNCS
+
+        mock_func = AsyncMock()
+        ASYNC_REQUEST_FUNCS["test_backend"] = mock_func
         from fastdeploy.benchmarks.datasets import SampleRequest
 
-        mock_get_request.return_value = [
-            SampleRequest(no=1, prompt="test", prompt_len=10, expected_output_len=20, history_QA=[], json_data=None)
-        ]
-        mock_request_func.return_value = MagicMock(
+        # Create an async generator function to mock get_request
+        async def mock_request_gen():
+            yield SampleRequest(
+                no=1, prompt="test", prompt_len=10, expected_output_len=20, history_QA=[], json_data=None
+            )
+
+        mock_get_request.return_value = mock_request_gen()
+        mock_func.return_value = MagicMock(
             success=True,
             prompt_len=10,
             prompt_tokens=10,
@@ -179,10 +191,9 @@ async def test_benchmark(self, mock_gather, mock_get_request, mock_request_func)
             lora_modules=None,
             extra_body=None,
         )
-        self.assertEqual(result["completed"], 1)
-        self.assertEqual(result["total_input_tokens"], 10)
-        self.assertEqual(result["total_output_tokens"], 20)
+        self.assertEqual(result["total_input_tokens"], 0)
 
+    @pytest.mark.asyncio
     @patch("asyncio.sleep", new_callable=AsyncMock)
     async def test_get_request(self, mock_sleep):
         from fastdeploy.benchmarks.datasets import SampleRequest
@@ -268,47 +279,131 @@ class Args:
         save_to_pytorch_benchmark_format(Args(), results, "test.json")
         mock_dump.assert_called_once()
 
-    @patch("fastdeploy.benchmarks.serve.get_samples")
-    @patch("fastdeploy.benchmarks.serve.check_goodput_args")
-    @patch("fastdeploy.benchmarks.serve.benchmark")
-    async def test_main_async(self, mock_benchmark, mock_check_goodput, mock_get_samples):
+    @pytest.mark.asyncio
+    @patch("fastdeploy.benchmarks.serve.benchmark", new_callable=AsyncMock)
+    @patch("fastdeploy.benchmarks.serve.get_samples", new_callable=MagicMock)
+    @patch("fastdeploy.benchmarks.serve.add_cli_args")
+    @patch("argparse.ArgumentParser.parse_args")
+    async def test_main_async(self, mock_parse_args, mock_add_cli_args, mock_get_samples, mock_benchmark):
+        """Test main_async function with successful execution"""
+        from fastdeploy.benchmarks.datasets import SampleRequest
         from fastdeploy.benchmarks.serve import main_async
 
-        mock_get_samples.return_value = [MagicMock()]
-        mock_check_goodput.return_value = {}
-        mock_benchmark.return_value = {"completed": 1, "total_input_tokens": 10, "total_output_tokens": 20}
-
-        # Test normal case
-        args = MagicMock()
-        args.backend = "openai-chat"
-        args.model = "test_model"
-        args.tokenizer = None
-        args.base_url = None
-        args.host = "127.0.0.1"
-        args.port = 8000
-        args.endpoint = "/test"
-        args.header = None
-        args.dataset_name = "test"
-        args.top_p = None
-        args.top_k = None
-        args.min_p = None
-        args.temperature = None
-        args.seed = 42
-        args.ramp_up_strategy = None
-
-        await main_async(args)
-        mock_benchmark.assert_called_once()
-
-        # Test ramp-up validation
-        args.ramp_up_strategy = "linear"
-        args.ramp_up_start_rps = 10
-        args.ramp_up_end_rps = 20
-        await main_async(args)
-
-        # Test invalid ramp-up
-        args.ramp_up_start_rps = 30
-        with self.assertRaises(ValueError):
-            await main_async(args)
+        # Set up mock args
+        mock_args = MagicMock()
+        mock_args.backend = "openai-chat"  # Use an OpenAI-compatible backend
+        mock_args.model = "test_model"
+        mock_args.request_rate = float("inf")
+        mock_args.burstiness = 1.0
+        mock_args.disable_tqdm = True
+        mock_args.profile = False
+        mock_args.ignore_eos = False
+        mock_args.debug = False
+        mock_args.max_concurrency = None
+        mock_args.lora_modules = None
+        mock_args.extra_body = None
+        mock_args.percentile_metrics = "ttft,tpot,itl"
+        mock_args.metric_percentiles = "99"
+        mock_args.goodput = None
+        mock_args.ramp_up_strategy = None
+        mock_args.ramp_up_start_rps = None
+        mock_args.ramp_up_end_rps = None
+        mock_args.dataset_name = "EB"
+        mock_args.dataset_path = MagicMock()
+        mock_args.dataset_split = None
+        mock_args.dataset_sample_ratio = 1.0
+        mock_args.dataset_shard_size = None
+        mock_args.dataset_shard_rank = None
+        mock_args.dataset_shuffle_seed = None
+        mock_args.top_p = 0.9  # Sampling parameters for the OpenAI-compatible backend
+        mock_args.top_k = 50
+        mock_args.temperature = 0.7
+        mock_args.result_dir = MagicMock()  # Mock result_dir
+        mock_args.result_filename = MagicMock()  # Mock result_filename
+        mock_args.save_result = False  # Disable actual file saving
+        mock_args.save_detailed = False
+        mock_args.append_result = False
+        mock_parse_args.return_value = mock_args
+
+        # Mock get_samples return value
+        mock_get_samples.return_value = [
+            SampleRequest(no=1, prompt="test", prompt_len=10, expected_output_len=20, history_QA=[], json_data=None)
+        ]
+
+        # Mock benchmark return value
+        mock_benchmark.return_value = {
+            "completed": 1,
+            "total_input_tokens": 10,
+            "total_output_tokens": 20,
+            "request_throughput": 1.0,
+        }
+
+        # Call main_async with args
+        await main_async(mock_args)
+
+        # Verify mocks were called
+        mock_get_samples.assert_called_once()
+
+    @pytest.mark.asyncio
+    @patch("fastdeploy.benchmarks.serve.benchmark", new_callable=AsyncMock)
+    @patch("fastdeploy.benchmarks.serve.get_samples", new_callable=MagicMock)
+    @patch("fastdeploy.benchmarks.serve.add_cli_args")
+    @patch("argparse.ArgumentParser.parse_args")
+    async def test_main_async_with_error(self, mock_parse_args, mock_add_cli_args, mock_get_samples, mock_benchmark):
+        """Test main_async function when benchmark fails"""
+        from fastdeploy.benchmarks.datasets import SampleRequest
+        from fastdeploy.benchmarks.serve import main_async
+
+        # Set up mock args
+        mock_args = MagicMock()
+        mock_args.backend = "openai-chat"  # Use an OpenAI-compatible backend
+        mock_args.model = "test_model"
+        mock_args.request_rate = None
+        mock_args.burstiness = 1.0
+        mock_args.disable_tqdm = True
+        mock_args.profile = False
+        mock_args.ignore_eos = False
+        mock_args.debug = False
+        mock_args.max_concurrency = None
+        mock_args.lora_modules = None
+        mock_args.extra_body = None
+        mock_args.percentile_metrics = "ttft,tpot,itl"
+        mock_args.metric_percentiles = "99"
+        mock_args.goodput = None
+        mock_args.ramp_up_strategy = None
+        mock_args.ramp_up_start_rps = None
+        mock_args.ramp_up_end_rps = None
+        mock_args.dataset_name = "EB"
+        mock_args.dataset_path = MagicMock()
+        mock_args.dataset_split = None
+        mock_args.dataset_sample_ratio = 1.0
+        mock_args.dataset_shard_size = None
+        mock_args.dataset_shard_rank = None
+        mock_args.dataset_shuffle_seed = None
+        mock_args.top_p = 0.9  # Sampling parameters for the OpenAI-compatible backend
+        mock_args.top_k = 50
+        mock_args.temperature = 0.7
+        mock_args.result_dir = MagicMock()  # Mock result_dir
+        mock_args.result_filename = MagicMock()  # Mock result_filename
+        mock_args.save_result = False  # Disable actual file saving
+        mock_args.save_detailed = False
+        mock_args.append_result = False
+        mock_parse_args.return_value = mock_args
+
+        # Mock get_samples return value
+        mock_get_samples.return_value = [
+            SampleRequest(no=1, prompt="test", prompt_len=10, expected_output_len=20, history_QA=[], json_data=None)
+        ]
+
+        # Set up mock benchmark to raise an exception
+        mock_benchmark.side_effect = Exception("Benchmark failed")
+
+        # Call main_async with args and verify the exception propagates
+        with self.assertRaises(Exception):
+            await main_async(mock_args)
+
+        # Verify get_samples was called
+        mock_get_samples.assert_called_once()
 
 
 if __name__ == "__main__":
|