
Commit 69e9f6d

[fix]: Skip prompt length checking for generation only requests (#6146)
Signed-off-by: Pengyun Lin <[email protected]>
1 parent 66030ef commit 69e9f6d

Showing 3 changed files with 21 additions and 9 deletions.

tensorrt_llm/disaggregated_params.py

Lines changed: 2 additions & 2 deletions
@@ -6,10 +6,10 @@

 @dataclass(slots=True, kw_only=True)
 class DisaggregatedParams:
-    """Disaggregated seving parameters.
+    """Disaggregated serving parameters.

     Args:
-        request_type (str): The type of request ("context_only" or "generation_only")
+        request_type (str): The type of request ("context_only" | "generation_only" | "context_and_generation")
         first_gen_tokens (List[int]): The first tokens of the generation request
         ctx_request_id (int): The context request id
         opaque_state(bytes): Any additional state needing to be exchanged between context and gen instances
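As illustration only, a minimal sketch of constructing these parameters for a generation-only request. The import path follows the file path above; the field values are hypothetical placeholders, not real handover state from a context instance:

from tensorrt_llm.disaggregated_params import DisaggregatedParams

# Hypothetical generation-only request: the prompt was already prefilled on a
# context instance, which handed back its first token, its request id, and an
# opaque KV-cache transfer blob.
gen_params = DisaggregatedParams(
    request_type="generation_only",
    first_gen_tokens=[42],   # first token(s) produced by the context phase
    ctx_request_id=1,        # id of the originating context request
    opaque_state=b"",        # backend-specific handover state (placeholder)
)

Because the dataclass is declared with kw_only=True, all fields must be passed as keyword arguments, as shown.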

tensorrt_llm/llmapi/llm.py

Lines changed: 10 additions & 7 deletions
@@ -334,9 +334,9 @@ def generate_async(
         # With pytorch backend, py_executor has logic to handle max_tokens of 1,
         # so set to 1 to avoid allocating unnecessary KV cache blocks for single request
         # TODO: Also support for trt backend
-        if (disaggregated_params is not None
-                and disaggregated_params.request_type == "context_only"
-                and not self._on_trt_backend):
+        is_ctx_only = disaggregated_params is not None and disaggregated_params.request_type == "context_only"
+        is_gen_only = disaggregated_params is not None and disaggregated_params.request_type == "generation_only"
+        if is_ctx_only and not self._on_trt_backend:
             sampling_params.max_tokens = 1

         inputs = prompt_inputs(inputs)
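This hunk replaces one inline condition with two named flags so that is_gen_only can be reused by the argument check later in generate_async. A self-contained sketch of the same gating, with the backend flag and sampling params stubbed out for illustration (names here are not the library API):

from dataclasses import dataclass

@dataclass
class _SamplingParams:  # stand-in for tensorrt_llm's SamplingParams
    max_tokens: int = 256

def resolve_disagg_flags(disaggregated_params, on_trt_backend: bool,
                         sampling_params: _SamplingParams):
    # Same logic as the diff: derive both flags once, then cap max_tokens at 1
    # for context-only requests on non-TRT backends, since a context-only
    # request produces exactly one token before handing off to generation.
    is_ctx_only = (disaggregated_params is not None
                   and disaggregated_params.request_type == "context_only")
    is_gen_only = (disaggregated_params is not None
                   and disaggregated_params.request_type == "generation_only")
    if is_ctx_only and not on_trt_backend:
        sampling_params.max_tokens = 1
    return is_ctx_only, is_gen_only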
@@ -401,7 +401,8 @@ def generate_async(
         self._check_arguments(
             len(prompt_token_ids),
             len(query_token_ids) if query_token_ids is not None else 0,
-            sampling_params)
+            sampling_params,
+            is_gen_only=is_gen_only)
         if _postproc_params:
             _postproc_params.postproc_args.num_prompt_tokens = len(
                 prompt_token_ids)
@@ -529,7 +530,8 @@ def _prepare_sampling_params(
         return sampling_params

     def _check_arguments(self, prompt_len: int, query_len: int,
-                         sampling_params: SamplingParams) -> None:
+                         sampling_params: SamplingParams,
+                         is_gen_only: bool) -> None:

         if self.args.backend in ["pytorch", "_autodeploy"]:
             # TODO: remove these checks after PyTorch backend
@@ -543,11 +545,12 @@ def _check_arguments(self, prompt_len: int, query_len: int,
                 f"PyTorch backend currently only supports `logprobs=1`. Received `logprobs={sampling_params.logprobs}` (Top{sampling_params.logprobs} logprobs). Please set `logprobs=1` in `sampling_params` instead."
             )
         # Check prompt length and query length against max_num_tokens to filter illegal requests.
-        if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill:
+        # Skip check for gen-only requests
+        if self.args.backend == "pytorch" and not self.args.enable_chunked_prefill and not is_gen_only:
             max_num_tokens = self.args.max_num_tokens
             if max_num_tokens and prompt_len / self.args.parallel_config.cp_size + query_len > max_num_tokens:
                 raise ValueError(
-                    f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) and max_tokens ({sampling_params.max_tokens}) should not exceed "
+                    f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}), query length ({query_len}) should not exceed "
                     f"max_num_tokens ({max_num_tokens})")
         return
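The rationale for the skip: a gen-only request arrives with a prompt that was already prefilled on the context instance, so validating it against this instance's max_num_tokens would wrongly reject legal requests. A standalone sketch of the check as it now behaves; the function name and flat argument list are illustrative, not the library API:

def check_prompt_fits(prompt_len: int, query_len: int, cp_size: int,
                      max_num_tokens: int, enable_chunked_prefill: bool,
                      is_gen_only: bool) -> None:
    # Chunked prefill splits long prompts across scheduler steps, and gen-only
    # requests were prefilled elsewhere, so the budget check applies only when
    # neither holds.
    if enable_chunked_prefill or is_gen_only:
        return
    if max_num_tokens and prompt_len / cp_size + query_len > max_num_tokens:
        raise ValueError(
            f"The sum of prompt length ({prompt_len / cp_size}) and query "
            f"length ({query_len}) should not exceed "
            f"max_num_tokens ({max_num_tokens})")

# check_prompt_fits(8192, 0, 1, 4096, False, is_gen_only=True) returns
# silently; the same call with is_gen_only=False raises ValueError.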

tensorrt_llm/llmapi/llm_args.py

Lines changed: 9 additions & 0 deletions
@@ -1357,6 +1357,15 @@ def set_runtime_knobs_from_build_config(self):

         return self

+    @model_validator(mode="after")
+    def validate_runtime_args(self):
+        if self.max_batch_size is not None and self.max_num_tokens is not None:
+            if self.max_batch_size > self.max_num_tokens:
+                logger.warning(
+                    f"max_batch_size [{self.max_batch_size}] should be less than or equal to max_num_tokens [{self.max_num_tokens}]"
+                )
+        return self
+
     @model_validator(mode="after")
     def validate_build_config_with_runtime_params(self):
         # Note: max_batch_size and max_num_tokens in LlmArgs are for runtime,
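Note that the new validator only warns rather than raises: max_batch_size > max_num_tokens is suspicious (each in-flight request consumes at least one token of budget per step, so the extra batch slots cannot be used) but construction still succeeds. The same pattern on a minimal stand-alone Pydantic v2 model, with RuntimeArgs as a hypothetical stand-in for LlmArgs:

import logging
from typing import Optional

from pydantic import BaseModel, model_validator

logger = logging.getLogger(__name__)

class RuntimeArgs(BaseModel):
    """Minimal stand-in for LlmArgs, showing the validator pattern."""
    max_batch_size: Optional[int] = None
    max_num_tokens: Optional[int] = None

    @model_validator(mode="after")
    def validate_runtime_args(self):
        # Warn (rather than raise) when the batch size exceeds the token
        # budget; the model remains usable either way.
        if self.max_batch_size is not None and self.max_num_tokens is not None:
            if self.max_batch_size > self.max_num_tokens:
                logger.warning(
                    f"max_batch_size [{self.max_batch_size}] should be less "
                    f"than or equal to max_num_tokens [{self.max_num_tokens}]")
        return self

# RuntimeArgs(max_batch_size=256, max_num_tokens=128) logs the warning but
# still constructs, because the validator returns self instead of raising.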
