Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/unit_tests/test_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_generate_prompt_buckets():
ctx_range = [0, 1, 2, 3, 4]
buckets = generate_buckets(bs_range, query_range, ctx_range, True, max_model_len, bs, prompt_bs,
max_num_batched_tokens, block_size, max_blocks)
assert len(buckets) == 40
assert len(buckets) == 17


def test_generate_decode_buckets():
Expand Down
21 changes: 11 additions & 10 deletions vllm_gaudi/extension/bucketing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,22 +243,23 @@ def expand_to_neighbor_buckets(bs_idx, bs_range, query_idx, query_range, max_num
'''
candidates = [(bs_idx, query_idx), (bs_idx + 1, query_idx), (bs_idx, query_idx + 1),
(bs_idx + 1, query_idx + 1)]
valid = bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens
if not valid:
omitted_buckets.add(("bs_range[bs_idx] * query_range[query_idx] <= max_num_batched_tokens",
"-> bs, query: ", bs_idx, query_idx))
return {}
valid_candidates = [(b_idx, q_idx) for b_idx, q_idx in candidates
if b_idx < len(bs_range) and q_idx < len(query_range)]
if (b_idx < len(bs_range) and q_idx < len(query_range))]
return {(bs_range[b_idx], query_range[q_idx]) for b_idx, q_idx in valid_candidates}

# filter rules for buckets
# prompt
def not_over_max_model_len(bs, query, ctx):
if not query + ctx * block_size <= max_model_len:
if not bs * (query + ctx * block_size) <= max_model_len:
omitted_buckets.add(
("condition: query + ctx * block_size <= max_model_len", "-> bs, quesry, ctx: ", bs, query, ctx))
return query + ctx * block_size <= max_model_len
("condition: bs * (query + ctx * block_size) <= max_model_len", "-> bs, query, ctx: ", bs, query, ctx))
return bs * (query + ctx * block_size) <= max_model_len

def not_over_max_num_batched_tokens(bs, query, ctx):
if not bs * query <= max_num_batched_tokens:
omitted_buckets.add(
("condition: bs * query <= max_num_batched_tokens", "-> bs, query, ctx: ", bs, query, ctx))
return bs * query <= max_num_batched_tokens

def ctx_not_over_max_ctx_for_merged_prefill(bs, query, ctx):
if not ctx <= max_num_prefill_seqs * math.ceil(
Expand All @@ -285,7 +286,7 @@ def batch_size_smaller_than_blocks(bs, query, ctx):
"prompt": {
# depends only on merged_prefill
True: [ctx_not_over_max_ctx_for_merged_prefill],
False: [not_over_max_model_len],
False: [not_over_max_model_len, not_over_max_num_batched_tokens],
},
"decode": {
# depends only on contiguous PA
Expand Down