diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py index 1a4b2762..76d96e96 100644 --- a/tests/unit_tests/test_bucketing.py +++ b/tests/unit_tests/test_bucketing.py @@ -50,7 +50,7 @@ def test_generate_prompt_buckets(): ctx_range = [0, 1, 2, 3, 4] buckets = generate_buckets(bs_range, query_range, ctx_range, True, max_model_len, bs, prompt_bs, max_num_batched_tokens, block_size, max_blocks) - assert len(buckets) == 40 + assert len(buckets) == 17 def test_generate_decode_buckets(): diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py index 13565fa9..2f6b6ded 100644 --- a/vllm_gaudi/extension/bucketing/common.py +++ b/vllm_gaudi/extension/bucketing/common.py @@ -243,22 +243,23 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num ''' candidates = [(bs_idx, query_idx), (bs_idx + 1, query_idx), (bs_idx, query_idx + 1), (bs_idx + 1, query_idx + 1)] - valid = bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens - if not valid: - omitted_buckets.add(("bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens", - "-> bs, quesry: ", bs_idx, query_idx)) - return {} valid_candidates = [(b_idx, q_idx) for b_idx, q_idx in candidates - if b_idx < len(bs_range) and q_idx < len(query_range)] + if (b_idx < len(bs_range) and q_idx < len(query_range))] return {(bs_range[b_idx], query_range[q_idx]) for b_idx, q_idx in valid_candidates} # filter rules for buckets # prompt def not_over_max_model_len(bs, query, ctx): - if not query + ctx * block_size <= max_model_len: + if not bs * (query + ctx * block_size) <= max_model_len: omitted_buckets.add( - ("condition: query + ctx * block_size <= max_model_len", "-> bs, quesry, ctx: ", bs, query, ctx)) - return query + ctx * block_size <= max_model_len + ("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx)) + return bs * (query + ctx * block_size) <= max_model_len + + def not_over_max_num_batched_tokens(bs, query, ctx): + if not bs * query <= max_num_batched_tokens: + omitted_buckets.add( + ("condition: bs * query <= max_num_batched_tokens", "-> bs, query, ctx: ", bs, query, ctx)) + return bs * query <= max_num_batched_tokens def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx): if not ctx <= max_num_prefill_seqs * math.ceil( @@ -285,7 +286,7 @@ def batch_size_smaller_than_blocks(bs, query, ctx): "prompt": { # depends only on merged_prefill True: [ctx_not_over_max_ctx_for_merged_prefill], - False: [not_over_max_model_len], + False: [not_over_max_model_len, not_over_max_num_batched_tokens], }, "decode": { # depends only on contiguous PA