diff --git a/examples/offline_inference/long_context.py b/examples/offline_inference/long_context.py
index 819b43e03..087d8e082 100644
--- a/examples/offline_inference/long_context.py
+++ b/examples/offline_inference/long_context.py
@@ -121,7 +121,7 @@ def round_up(t):
 
 
 tokens_to_generate = [
-    args.max_model_len + 1 - round_up(prompt_len) for prompt_len in prompt_lens
+    args.max_model_len - round_up(prompt_len) for prompt_len in prompt_lens
 ]
 
 sampling_params = [
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index 685dd8440..dd7100a9d 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -340,8 +340,7 @@ def validate_request(
             # ceil division to pad to next block boundary
             prompt_padding_len = math.ceil(
                 prompt_len / cls._block_size) * cls._block_size
-            # we have to account for the token generated during prefill (-1)
-            if (prompt_padding_len + max_tokens - 1
+            if (prompt_padding_len + max_tokens
                     > cls._config.scheduler_config.max_model_len):
                 raise ValueError(
                     "Could not add request: prompt length is "
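
For reference, a minimal standalone sketch of the length check as it reads after this change. The function name, block size, and context length below are illustrative assumptions, not the plugin's actual API or defaults: the prompt is padded up to the next block boundary with ceil division, and the padded length plus max_tokens must not exceed max_model_len, with no -1 credit for the token produced during prefill.

import math


def fits_in_context(prompt_len: int,
                    max_tokens: int,
                    block_size: int = 64,
                    max_model_len: int = 2048) -> bool:
    """Return True if a block-padded prompt plus its new tokens fit.

    Illustrative sketch of the post-patch check: pad the prompt to the
    next block boundary via ceil division, then require that the padded
    prompt plus max_tokens stays within max_model_len (no -1 for the
    token generated during prefill).
    """
    prompt_padding_len = math.ceil(prompt_len / block_size) * block_size
    return prompt_padding_len + max_tokens <= max_model_len


# A 100-token prompt pads to 128; with max_model_len=2048 it can request
# up to 1920 new tokens, but 1921 would be rejected.
assert fits_in_context(100, 1920)
assert not fits_in_context(100, 1921)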