diff --git a/sharktank/sharktank/utils/llm_scheduler.py b/sharktank/sharktank/utils/llm_scheduler.py
new file mode 100644
index 0000000000..48ee4e9fb6
--- /dev/null
+++ b/sharktank/sharktank/utils/llm_scheduler.py
@@ -0,0 +1,53 @@
+from dataclasses import dataclass
+from sharktank.utils.llm_task import LlmTaskInput
+from typing import List, Dict
+
+
+@dataclass
+class SchedulerEntry:
+    task: LlmTaskInput
+    remaining_count: int
+
+
+class Scheduler:
+    def __init__(self, batch_size: int):
+        self._batch_size = batch_size
+        self._queue: List[LlmTaskInput] = []
+        self._schedule_count: Dict[str, SchedulerEntry] = {}
+
+    def schedule_task(self, task: LlmTaskInput, count: int):
+        self._queue.append(task)
+        self._schedule_count[task.task_id] = SchedulerEntry(
+            task=task, remaining_count=count
+        )
+
+    def get_scheduled_tasks(self) -> List[LlmTaskInput]:
+        batch_size = self._batch_size
+        scheduled_tasks = self._queue[:batch_size]
+        self._queue = self._queue[batch_size:]
+
+        for task in scheduled_tasks:
+            self._schedule_count[task.task_id].remaining_count -= 1
+
+        return scheduled_tasks
+
+    def has_pending_tasks(self) -> bool:
+        return len(self._queue) > 0
+
+    def on_task_complete(self, task_id: str, last_token: int) -> bool:
+        task = self._schedule_count[task_id].task
+        task.tokens.append(last_token)
+        task.seq_len = len(task.tokens)
+        task.start_position = task.seq_len - 1
+
+        if self._schedule_count[task_id].remaining_count <= 0:
+            del self._schedule_count[task_id]
+            return True
+
+        self._queue.append(task)
+        return False
+
+    def remove_task(self, task_id: str):
+        self._queue = [task for task in self._queue if task.task_id != task_id]
+        if task_id in self._schedule_count:
+            del self._schedule_count[task_id]
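
Reviewer note: a minimal, runnable sketch of the contract this class implements, using only the new modules above. Tasks are queued round-robin, get_scheduled_tasks pops at most batch_size of them and charges one slot against each, and on_task_complete re-queues a task until its slot count is exhausted. The token and page values are placeholders.

    from sharktank.utils.llm_scheduler import Scheduler
    from sharktank.utils.llm_task import LlmTaskInput

    scheduler = Scheduler(batch_size=2)
    for i in range(3):
        scheduler.schedule_task(
            LlmTaskInput(task_id=f"req_{i}", tokens=[1, 2, 3], seq_len=3, pages=[i]),
            count=2,  # each task gets two scheduling slots
        )

    batches = 0
    while scheduler.has_pending_tasks():
        for task in scheduler.get_scheduled_tasks():  # at most batch_size tasks
            # Appends the token, advances start_position, and re-queues the
            # task unless its slot count is exhausted.
            scheduler.on_task_complete(task.task_id, last_token=99)
        batches += 1

    assert batches == 3  # 3 tasks x 2 slots, drained 2 at a time
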
diff --git a/sharktank/sharktank/utils/llm_task.py b/sharktank/sharktank/utils/llm_task.py
new file mode 100644
index 0000000000..bc53cf13a0
--- /dev/null
+++ b/sharktank/sharktank/utils/llm_task.py
@@ -0,0 +1,167 @@
+import dataclasses
+import iree.runtime
+import numpy
+
+from abc import ABC, abstractmethod
+from typing import Callable, List, Optional, Tuple
+
+import torch
+
+
+def fill_page_table(bs: int, count: int, page_ids: list[list[int]]) -> numpy.ndarray:
+    pages = numpy.zeros((bs, count), dtype=numpy.int64)
+
+    for i, ids in enumerate(page_ids):
+        pages[i, : len(ids)] = ids[:count]
+
+    return pages
+
+
+@dataclasses.dataclass
+class LlmTaskInput:
+    task_id: str
+    tokens: List[int]
+    seq_len: int
+    pages: List[int]
+
+    start_position: Optional[int] = None
+
+
+class LlmTask(ABC):
+    def __init__(
+        self,
+        invocation_fn: Callable[
+            [List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]],
+            Tuple[numpy.ndarray, Optional[numpy.ndarray]],
+        ],
+        llm_task_inputs: List[LlmTaskInput],
+        batch_size: int,
+        block_stride: int,
+    ):
+        self._invocation_fn = invocation_fn
+        self._task_inputs: List[LlmTaskInput] = llm_task_inputs
+        self._batch_size = batch_size
+        self._block_stride = block_stride
+
+    @abstractmethod
+    def _prepare_args(
+        self, task_inputs: List[LlmTaskInput], cache
+    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
+        pass
+
+    @abstractmethod
+    def _process_results(
+        self, results
+    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
+        pass
+
+    def run(
+        self, cache_state: iree.runtime.DeviceArray | torch.Tensor
+    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
+        task_inputs = self._task_inputs
+
+        args = self._prepare_args(task_inputs, cache_state)
+        results = self._invocation_fn(*args)
+        logits, indices = self._process_results(results)
+        return logits, indices
+
+
+class PrefillTask(LlmTask):
+    def _prepare_args(
+        self,
+        task_inputs: List[LlmTaskInput],
+        cache: iree.runtime.DeviceArray | torch.Tensor,
+    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
+        block_stride = self._block_stride
+        bs = self._batch_size
+
+        tokens = [task_input.tokens for task_input in task_inputs]
+        page_ids = [task_input.pages for task_input in task_inputs]
+
+        max_len = max(len(input_tokens) for input_tokens in tokens)
+        blocks = int(numpy.ceil(max_len / block_stride))
+        blocked_len = blocks * block_stride
+
+        tokens_ = numpy.zeros((bs, blocked_len), dtype=numpy.int64)
+        lens_ = numpy.ones((bs,), dtype=numpy.int64)
+
+        for i, input_tokens in enumerate(tokens):
+            tokens_[i, : len(input_tokens)] = input_tokens
+            lens_[i] = len(input_tokens)
+
+        pages_ = fill_page_table(bs, blocks, page_ids)
+        args = [
+            tokens_,
+            lens_,
+            pages_,
+            cache,
+        ]
+        return args
+
+    def _process_results(
+        self, results
+    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
+        if isinstance(results, tuple):
+            logits, indices = results
+            logits = numpy.asarray(logits)
+            indices = numpy.asarray(indices)
+        else:
+            logits = numpy.asarray(results)
+            indices = None
+        return logits, indices
+
+
+class DecodeTask(LlmTask):
+    def _prepare_args(
+        self,
+        task_inputs: List[LlmTaskInput],
+        cache: iree.runtime.DeviceArray | torch.Tensor,
+    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
+        assert all(
+            task_input.start_position is not None for task_input in task_inputs
+        ), "`start_position` must be set on every task input for `decode`"
+
+        block_stride = self._block_stride
+        decode_bs = self._batch_size
+        bs = len(task_inputs)
+
+        tokens = [task_input.tokens[-1] for task_input in task_inputs]
+        page_ids = [task_input.pages for task_input in task_inputs]
+        start_positions = [task_input.start_position for task_input in task_inputs]
+
+        max_len = max(start_positions) + 1
+        blocks = int(numpy.ceil(max_len / block_stride))
+
+        tokens_ = numpy.zeros((decode_bs, 1), dtype=numpy.int64)
+        lens_ = numpy.ones((decode_bs,), dtype=numpy.int64)
+        pos_ = numpy.ones((decode_bs,), dtype=numpy.int64)
+
+        for i in range(bs):
+            tokens_[i, 0] = tokens[i]
+            lens_[i] = start_positions[i] + 1
+            pos_[i] = start_positions[i]
+
+        pages_ = fill_page_table(decode_bs, blocks, page_ids)
+
+        args = [
+            tokens_,
+            lens_,
+            pos_,
+            pages_,
+            cache,
+        ]
+        return args
+
+    def _process_results(
+        self, results
+    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
+        if isinstance(results, tuple):
+            logits, indices = results
+        else:
+            k = 8
+            logits = torch.asarray(numpy.asarray(results))
+            logits, indices = torch.topk(logits, k)
+
+        logits = numpy.asarray(logits)
+        indices = numpy.asarray(indices)
+        return logits, indices
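
Reviewer note: the shape bookkeeping in PrefillTask._prepare_args, condensed into a self-contained snippet. The token and page values are illustrative only; the point is that prompts are right-padded to a whole number of blocks and page rows are zero-padded to the block count.

    import numpy
    from sharktank.utils.llm_task import fill_page_table

    block_stride = 4
    tokens = [[101, 102, 103, 104, 105], [201, 202]]  # two requests
    page_ids = [[1, 2], [3]]
    bs = len(tokens)

    max_len = max(len(t) for t in tokens)             # 5
    blocks = int(numpy.ceil(max_len / block_stride))  # 2 blocks
    blocked_len = blocks * block_stride               # padded to 8

    tokens_ = numpy.zeros((bs, blocked_len), dtype=numpy.int64)
    lens_ = numpy.ones((bs,), dtype=numpy.int64)
    for i, t in enumerate(tokens):
        tokens_[i, : len(t)] = t
        lens_[i] = len(t)

    pages_ = fill_page_table(bs, blocks, page_ids)  # zero-padded per row
    assert tokens_.shape == (2, 8)
    assert list(pages_[1]) == [3, 0]
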
diff --git a/sharktank/sharktank/utils/llm_utils.py b/sharktank/sharktank/utils/llm_utils.py
index 80495a51c7..1b4122ab21 100644
--- a/sharktank/sharktank/utils/llm_utils.py
+++ b/sharktank/sharktank/utils/llm_utils.py
@@ -36,6 +36,8 @@
 from sharktank.models.llm import PagedLlmModelV1
 from sharktank.types import Dataset, Theta
 from sharktank.utils.attention import *
+from sharktank.utils.llm_scheduler import Scheduler
+from sharktank.utils.llm_task import LlmTaskInput, PrefillTask, DecodeTask
 from typing import Callable, List, Optional, Tuple

 np_dtype_to_torch_dtype = {
@@ -268,169 +270,6 @@ def allocate(self, *shape, dtype, device_index: int):
         return torch.zeros(*shape, dtype=dtype, device=self._device)


-def fill_page_table(bs: int, count: int, page_ids: list[list[int]]) -> numpy.ndarray:
-    pages = numpy.zeros((bs, count), dtype=numpy.int64)
-
-    for i, ids in enumerate(page_ids):
-        pages[i, : len(ids)] = ids[:count]
-
-    return pages
-
-
-class LlmTaskType(Enum):
-    PREFILL = auto()
-    DECODE = auto()
-
-
-@dataclasses.dataclass
-class LlmTaskInput:
-    tokens: List[int]
-    seq_len: int
-    pages: List[int]
-
-    start_position: Optional[int] = None
-
-
-class LlmTask(ABC):
-    def __init__(
-        self,
-        invocation_fn: Callable[
-            [List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]],
-            Tuple[numpy.ndarray, Optional[numpy.ndarray]],
-        ],
-        llm_task_inputs: List[LlmTaskInput],
-        batch_size: int,
-        block_stride: int,
-    ):
-        self._invocation_fn = invocation_fn
-        self._task_inputs: List[LlmTaskInput] = llm_task_inputs
-        self._batch_size = batch_size
-        self._block_stride = block_stride
-
-    @abstractmethod
-    def _prepare_args(
-        self, task_inputs: List[LlmTaskInput], cache
-    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
-        pass
-
-    @abstractmethod
-    def _process_results(
-        self, results
-    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
-        pass
-
-    def run(
-        self, cache_state: iree.runtime.DeviceArray | torch.Tensor
-    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
-        task_inputs = self._task_inputs
-
-        args = self._prepare_args(task_inputs, cache_state)
-        results = self._invocation_fn(*args)
-        logits, indices = self._process_results(results)
-        return logits, indices
-
-
-class PrefillTask(LlmTask):
-    def _prepare_args(
-        self,
-        task_inputs: List[LlmTaskInput],
-        cache: iree.runtime.DeviceArray | torch.Tensor,
-    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
-        block_stride = self._block_stride
-        bs = self._batch_size
-
-        tokens = [task_input.tokens for task_input in task_inputs]
-        page_ids = [task_input.pages for task_input in task_inputs]
-
-        max_len = max(len(input_tokens) for input_tokens in tokens)
-        blocks = int(numpy.ceil(max_len / block_stride))
-        blocked_len = blocks * block_stride
-
-        tokens_ = numpy.zeros((bs, blocked_len), dtype=numpy.int64)
-        lens_ = numpy.ones((bs,), dtype=numpy.int64)
-
-        for i, input_tokens in enumerate(tokens):
-            tokens_[i, : len(input_tokens)] = input_tokens
-            lens_[i] = len(input_tokens)
-
-        pages_ = fill_page_table(bs, blocks, page_ids)
-        args = [
-            tokens_,
-            lens_,
-            pages_,
-            cache,
-        ]
-        return args
-
-    def _process_results(
-        self, results
-    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
-        if isinstance(results, tuple):
-            logits, indices = results
-            logits = numpy.asarray(logits)
-            indices = numpy.asarray(indices)
-        else:
-            logits = numpy.asarray(results)
-            indices = None
-        return logits, indices
-
-
-class DecodeTask(LlmTask):
-    def _prepare_args(
-        self,
-        task_inputs: List[LlmTaskInput],
-        cache: iree.runtime.DeviceArray | torch.Tensor,
-    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
-        assert all(
-            task_input.start_position is not None for task_input in task_inputs
-        ), "`start_positions` is a required argument for `decode`"
-
-        block_stride = self._block_stride
-        decode_bs = self._batch_size
-        bs = len(task_inputs)
-
-        tokens = [task_input.tokens[0] for task_input in task_inputs]
-        page_ids = [task_input.pages for task_input in task_inputs]
-        start_positions = [task_input.start_position for task_input in task_inputs]
-
-        max_len = max(start_positions) + 1
-        blocks = int(numpy.ceil(max_len / block_stride))
-
-        tokens_ = numpy.zeros((decode_bs, 1), dtype=numpy.int64)
-        lens_ = numpy.ones((decode_bs,), dtype=numpy.int64)
-        pos_ = numpy.ones((decode_bs,), dtype=numpy.int64)
-
-        for i in range(bs):
-            tokens_[i, 0] = tokens[i]
-            lens_[i] = start_positions[i] + 1
-            pos_[i] = start_positions[i]
-
-        pages_ = fill_page_table(decode_bs, blocks, page_ids)
-
-        args = [
-            tokens_,
-            lens_,
-            pos_,
-            pages_,
-            cache,
-        ]
-        return args
-
-    def _process_results(
-        self, results
-    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
-        if isinstance(results, tuple):
-            logits, indices = results
-        else:
-            k = 8
-            logits = torch.asarray(numpy.asarray(results))
-            logits, indices = torch.topk(logits, k)
-
-        logits = numpy.asarray(logits)
-        indices = numpy.asarray(indices)
-        return logits, indices
-
-
 class LlmAllocator:
     def __init__(self, page_count, block_stride):
         self._pages = list(range(1, page_count))
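
Reviewer note: one behavioral change is worth flagging in this code motion. The old DecodeTask above read task_input.tokens[0], while the relocated copy in llm_task.py reads tokens[-1], because Scheduler.on_task_complete now appends each sampled token to the shared LlmTaskInput. A runnable illustration of that hand-off, with placeholder token and page ids:

    from sharktank.utils.llm_scheduler import Scheduler
    from sharktank.utils.llm_task import LlmTaskInput

    prompt = [5, 6, 7, 8, 9]  # five tokens, positions 0..4
    task = LlmTaskInput(
        task_id="req_0", tokens=list(prompt), seq_len=len(prompt), pages=[1, 2]
    )

    scheduler = Scheduler(batch_size=1)
    scheduler.schedule_task(task, count=1)
    scheduler.get_scheduled_tasks()
    scheduler.on_task_complete("req_0", last_token=42)

    # The next decode step feeds tokens[-1] at position start_position.
    assert task.tokens[-1] == 42
    assert task.seq_len == 6
    assert task.start_position == 5  # the new token's position
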
@@ -470,7 +309,12 @@ def __init__(

         self._allocator = LlmAllocator(page_count=page_count, block_stride=block_stride)

-        self._allocator = LlmAllocator(page_count=page_count, block_stride=block_stride)
+        self._prefill_scheduler = Scheduler(
+            batch_size=self._prefill_bs,
+        )
+        self._decode_scheduler = Scheduler(
+            batch_size=self._decode_bs,
+        )

         self._cache = [
             instance.allocate(
@@ -490,13 +334,18 @@ def allocate(
     def free(self, pages: list[int]):
         self._allocator.free(pages)

-    def prefill(self, requests: list[list[int]], page_ids: list[list[int]]):
+    def prefill(
+        self,
+        requests: list[list[int]],
+        page_ids: list[list[int]],
+    ):
         assert len(requests) == len(page_ids)

         task_inputs = []
         for i, request in enumerate(requests):
             task_inputs.append(
                 LlmTaskInput(
+                    task_id=f"req_{i}",
                     tokens=request,
                     seq_len=len(request),
                     pages=page_ids[i],
@@ -513,7 +362,10 @@ def decode(
         return logits, indices

     def decode(
-        self, tokens: list[int], positions: list[int], page_ids: list[list[int]]
+        self,
+        tokens: list[int],
+        positions: list[int],
+        page_ids: list[list[int]],
     ):
         assert len(tokens) == len(positions)
         assert len(tokens) == len(page_ids)
@@ -522,6 +374,7 @@ def decode(
         for i, token in enumerate(tokens):
             task_inputs.append(
                 LlmTaskInput(
+                    task_id=f"req_{i}",
                     tokens=[token],
                     seq_len=positions[i] + 1,
                     start_position=positions[i],
@@ -538,6 +391,87 @@
         logits, indices = decode_task.run(*self._cache)
         return logits, indices

+    def submit_tasks(
+        self,
+        requests: list[list[int]],
+        page_ids: list[list[int]],
+        steps: int,
+    ):
+        assert len(requests) == len(page_ids)
+
+        for i, request in enumerate(requests):
+            task_input = LlmTaskInput(
+                task_id=f"req_{i}",
+                tokens=request,
+                seq_len=len(request),
+                pages=page_ids[i],
+            )
+            # Submit the single prefill task
+            self._prefill_scheduler.schedule_task(task_input, 1)
+            # Submit the decode tasks; with steps == 1 there are none
+            if steps > 1:
+                self._decode_scheduler.schedule_task(task_input, steps - 1)
+
+    def run(
+        self,
+        selection_fn: Callable[
+            [numpy.ndarray, Optional[numpy.ndarray], list[int]],
+            list[int],
+        ],
+        eos_token: int | None,
+    ):
+        selections_map = {}
+        input_keys = []
+
+        # Run prefill, one batch per iteration, until every scheduled
+        # request has been prefilled
+        while self._prefill_scheduler.has_pending_tasks():
+            task_inputs = self._prefill_scheduler.get_scheduled_tasks()
+            for task in task_inputs:
+                selections_map[task.task_id] = []
+                input_keys.append(task.task_id)
+
+            prefill_task = PrefillTask(
+                invocation_fn=self._instance.prefill,
+                llm_task_inputs=task_inputs,
+                batch_size=self._prefill_bs,
+                block_stride=self._block_stride,
+            )
+            logits, indices = prefill_task.run(*self._cache)
+            last = selection_fn(
+                logits,
+                indices,
+                [task_input.seq_len - 1 for task_input in task_inputs],
+            )
+            for task_input, token in zip(task_inputs, last):
+                selections_map[task_input.task_id].append(token)
+                self._prefill_scheduler.on_task_complete(task_input.task_id, token)
+                if token == eos_token:
+                    self._decode_scheduler.remove_task(task_input.task_id)
+
+        # Run decode steps until every task has exhausted its scheduled
+        # steps or produced the EOS token
+        while self._decode_scheduler.has_pending_tasks():
+            task_inputs = self._decode_scheduler.get_scheduled_tasks()
+            decode_task = DecodeTask(
+                invocation_fn=self._instance.decode,
+                llm_task_inputs=task_inputs,
+                batch_size=self._decode_bs,
+                block_stride=self._block_stride,
+            )
+            logits, indices = decode_task.run(*self._cache)
+            last = selection_fn(logits, indices, [0] * len(task_inputs))
+            for task_input, token in zip(task_inputs, last):
+                selections_map[task_input.task_id].append(token)
+                self._decode_scheduler.on_task_complete(
+                    task_input.task_id,
+                    token,
+                )
+                if token == eos_token:
+                    self._decode_scheduler.remove_task(task_input.task_id)
+
+        return [selections_map[key] for key in input_keys]
+

 class LlmDecoder:
     def __init__(self, batch):
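
Reviewer note: run() drains the decode scheduler with `while has_pending_tasks()` rather than a fixed `steps - 1` loop because, once there are more live tasks than the decode batch size, each iteration only advances a subset of them; the iteration count has to come from the schedule, not from `steps`. A self-contained check with dummy tokens:

    from sharktank.utils.llm_scheduler import Scheduler
    from sharktank.utils.llm_task import LlmTaskInput

    steps, decode_bs, n_requests = 4, 2, 4
    scheduler = Scheduler(batch_size=decode_bs)
    for i in range(n_requests):
        scheduler.schedule_task(
            LlmTaskInput(task_id=f"req_{i}", tokens=[0], seq_len=1, pages=[i]),
            count=steps - 1,
        )

    iterations = 0
    while scheduler.has_pending_tasks():
        for task in scheduler.get_scheduled_tasks():
            scheduler.on_task_complete(task.task_id, last_token=7)
        iterations += 1

    # 4 tasks x 3 decode slots, drained 2 per batch: 6 iterations, not steps - 1 == 3.
    assert iterations == 6
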
@@ -557,41 +491,17 @@ def _greedy_select(self, logits, indices, positions):
     def greedy_decode(
         self, requests: list[list[int]], steps: int, eos: int | None = None
     ):
-        selections = []
-        positions = [len(request) - 1 for request in requests]
-
         page_ids = [
             self._batch.allocate(token_count=len(req) + steps) for req in requests
         ]

-        logits, indices = self._batch.prefill(requests, page_ids=page_ids)
-        last = self._greedy_select(logits, indices, positions)
-        done = [False for _ in range(len(requests))]
-        done = [d or t == eos for d, t in zip(done, last)]
-        selections.append(last)
+        self._batch.submit_tasks(requests=requests, page_ids=page_ids, steps=steps)
+        selections = self._batch.run(
+            selection_fn=self._greedy_select,
+            eos_token=eos,
+        )

-        for _ in range(steps - 1):
-            if all(done):
-                break
-            positions = [p + 1 for p in positions]
-            logits, indices = self._batch.decode(
-                tokens=last, positions=positions, page_ids=page_ids
-            )
-            last = self._greedy_select(logits, indices, [0] * len(requests))
-            done = [d or t == eos for d, t in zip(done, last)]
-            selections.append(last)
-
-        results = [[] for i in range(len(selections[0]))]
-        for select in selections:
-            for j, token in enumerate(select):
-                results[j].append(token.item())
-
-        eos_pos = [[i for i, t in enumerate(result) if t == eos] for result in results]
-        results = [
-            result[: pos[0] + 1] if len(pos) > 0 else result
-            for result, pos in zip(results, eos_pos)
-        ]
-        return results
+        return selections


 class LlmBencher:
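
Reviewer note: to close the loop, a sketch of how the new surface composes end to end. `batch` is assumed to be an already-constructed instance of the class patched above, and the `greedy` helper assumes logits shaped [bs, seq, vocab] (or top-k [bs, seq, k] with `indices`); neither shape is pinned down by this diff, so treat this as illustrative, not canonical.

    import numpy
    from typing import List, Optional

    def greedy(
        logits: numpy.ndarray,
        indices: Optional[numpy.ndarray],
        positions: List[int],
    ) -> List[int]:
        # Argmax at each request's last position; map through `indices`
        # when the invocation returned a top-k (logits, indices) pair.
        picks = []
        for row, pos in enumerate(positions):
            col = int(numpy.argmax(logits[row, pos]))
            picks.append(int(indices[row, pos, col]) if indices is not None else col)
        return picks

    def generate(batch, prompts: List[List[int]], steps: int, eos: int) -> List[List[int]]:
        # Allocate pages, submit prefill + decode work, then drain both schedulers.
        page_ids = [batch.allocate(token_count=len(p) + steps) for p in prompts]
        batch.submit_tasks(requests=prompts, page_ids=page_ids, steps=steps)
        try:
            # One token list per prompt, in submission order, stopping early
            # at EOS since finished tasks are removed from the decode scheduler.
            return batch.run(selection_fn=greedy, eos_token=eos)
        finally:
            for pages in page_ids:
                batch.free(pages)
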