|
14 | 14 | from typing import AsyncIterator, Dict, Optional, Tuple
|
15 | 15 |
|
16 | 16 | import grpc
|
| 17 | +from google.protobuf.json_format import MessageToDict |
17 | 18 | from grpc_reflection.v1alpha import reflection
|
18 | 19 |
|
19 | 20 | from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
|
@@ -483,28 +484,52 @@ def _convert_sampling_params(
|
483 | 484 | elif grpc_params.HasField("structural_tag"):
|
484 | 485 | structural_tag = grpc_params.structural_tag
|
485 | 486 |
|
| 487 | + # Convert optional fields: unset or empty values are passed on as None |
| 488 | + custom_params = ( |
| 489 | + MessageToDict(grpc_params.custom_params) |
| 490 | + if grpc_params.HasField("custom_params") |
| 491 | + else None |
| 492 | + ) |
| 493 | + max_new_tokens = ( |
| 494 | + grpc_params.max_new_tokens |
| 495 | + if grpc_params.HasField("max_new_tokens") |
| 496 | + else None |
| 497 | + ) |
| 498 | + stream_interval = ( |
| 499 | + grpc_params.stream_interval |
| 500 | + if grpc_params.HasField("stream_interval") |
| 501 | + else None |
| 502 | + ) |
| 503 | + logit_bias = dict(grpc_params.logit_bias) if grpc_params.logit_bias else None |
| 504 | + stop = list(grpc_params.stop) if grpc_params.stop else None |
| 505 | + stop_token_ids = ( |
| 506 | + list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None |
| 507 | + ) |
| 508 | + |
486 | 509 | return SGLSamplingParams(
|
487 |
| - temperature=grpc_params.temperature or 1.0, |
488 |
| - top_p=grpc_params.top_p or 1.0, |
489 |
| - top_k=grpc_params.top_k or -1, |
490 |
| - min_p=grpc_params.min_p or 0.0, |
491 |
| - frequency_penalty=grpc_params.frequency_penalty or 0.0, |
492 |
| - presence_penalty=grpc_params.presence_penalty or 0.0, |
493 |
| - repetition_penalty=grpc_params.repetition_penalty or 1.0, |
494 |
| - max_new_tokens=grpc_params.max_new_tokens or 128, |
495 |
| - min_new_tokens=grpc_params.min_new_tokens or 0, |
496 |
| - stop=list(grpc_params.stop) if grpc_params.stop else [], |
497 |
| - stop_token_ids=( |
498 |
| - list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else [] |
499 |
| - ), |
| 510 | + temperature=grpc_params.temperature, |
| 511 | + top_p=grpc_params.top_p, |
| 512 | + top_k=grpc_params.top_k, |
| 513 | + min_p=grpc_params.min_p, |
| 514 | + frequency_penalty=grpc_params.frequency_penalty, |
| 515 | + presence_penalty=grpc_params.presence_penalty, |
| 516 | + repetition_penalty=grpc_params.repetition_penalty, |
| 517 | + max_new_tokens=max_new_tokens, |
| 518 | + min_new_tokens=grpc_params.min_new_tokens, |
| 519 | + stop=stop, |
| 520 | + stop_token_ids=stop_token_ids, |
500 | 521 | skip_special_tokens=grpc_params.skip_special_tokens,
|
501 | 522 | spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
|
| 523 | + no_stop_trim=grpc_params.no_stop_trim, |
502 | 524 | regex=regex,
|
503 | 525 | json_schema=json_schema,
|
504 | 526 | ebnf=ebnf_grammar,
|
505 | 527 | structural_tag=structural_tag,
|
506 |
| - n=grpc_params.n or 1, |
| 528 | + n=grpc_params.n, |
507 | 529 | ignore_eos=grpc_params.ignore_eos,
|
| 530 | + stream_interval=stream_interval, |
| 531 | + logit_bias=logit_bias, |
| 532 | + custom_params=custom_params, |
508 | 533 | )
|
509 | 534 |
|
510 | 535 | def _convert_output_logprobs_to_proto(
|
|
0 commit comments