Skip to content

Commit c6fbb11

Browse files
fix tests
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent ce84b41 commit c6fbb11

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

tests/v1/attention/test_attention_splitting.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,14 @@ def test_prefill_split_across_ubatches(
294294
qsl_np = common.query_start_loc_cpu.numpy()
295295
num_tokens = common.num_actual_tokens
296296

297-
ubatch_slices = maybe_create_ubatch_slices(
298-
True, num_scheduled_tokens, num_tokens, batch_spec.batch_size
297+
ubatch_slices, _ = maybe_create_ubatch_slices(
298+
True,
299+
num_scheduled_tokens,
300+
num_tokens,
301+
batch_spec.batch_size,
302+
split_point=split_point,
299303
)
300-
assert len(ubatch_slices) == 2
304+
assert ubatch_slices is not None and len(ubatch_slices) == 2
301305

302306
first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
303307
second_meta = _make_metadata_with_slice(ubatch_slices[1], common)

vllm/v1/worker/ubatch_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ def maybe_create_ubatch_slices(
6565
num_scheduled_tokens: np.ndarray,
6666
num_tokens_padded: int,
6767
num_reqs_padded: int,
68+
split_point: int | None = None,
6869
) -> tuple[UBatchSlices | None, UBatchSlices | None]:
6970
if not should_ubatch:
7071
return None, None
7172

72-
split_point = int(num_tokens_padded) // 2
73+
if split_point is None:
74+
split_point = int(num_tokens_padded) // 2
7375

7476
# TODO(lucas): Refactor the gpu_model_runner.py so we can pass
7577
# in cu_num_tokens directly (i.e. query_start_loc)

0 commit comments

Comments
 (0)