[Sharktank] Llm Task Scheduler #2418
base: main
Changes from all commits: ac04cd4, 012f25f, 6c894f4, e405218, cf3aeb4
@@ -0,0 +1,53 @@
from dataclasses import dataclass
from sharktank.utils.llm_task import LlmTaskInput
from typing import List, Dict


@dataclass
class SchedulerEntry:
    task: LlmTaskInput
    remaining_count: int


class Scheduler:
    def __init__(self, batch_size: int):
        self._batch_size = batch_size
        self._queue: List[LlmTaskInput] = []
        self._schedule_count: Dict[str, SchedulerEntry] = {}

    def schedule_task(self, task: LlmTaskInput, count: int):
        self._queue.append(task)
        self._schedule_count[task.task_id] = SchedulerEntry(
            task=task, remaining_count=count
        )

    def get_scheduled_tasks(self) -> List[LlmTaskInput]:
        batch_size = self._batch_size
        schedule_tasks = self._queue[:batch_size]
        self._queue = self._queue[batch_size:]

        for task in schedule_tasks:
            self._schedule_count[task.task_id].remaining_count -= 1

        return schedule_tasks

    def has_pending_tasks(self) -> bool:
        return len(self._queue) > 0

    def on_task_complete(self, task_id: str, last_token: int) -> bool:
        task = self._schedule_count[task_id].task
        task.tokens.append(last_token)
        task.seq_len = len(task.tokens)
        task.start_position = task.seq_len - 1

        if self._schedule_count[task_id].remaining_count == 0:
            del self._schedule_count[task_id]
            return True

        self._queue.append(self._schedule_count[task_id].task)
        return False

    def remove_task(self, task_id: str):
        self._queue = [task for task in self._queue if task.task_id != task_id]
        if task_id in self._schedule_count:
            del self._schedule_count[task_id]

Review comment on `Scheduler.on_task_complete`: I am pretty sure this is a bad section. Calling a callback on completion is likely going to produce bad results.
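For orientation (not part of the diff), the intended driving loop might look roughly like the sketch below; the `Scheduler` import path and the constant standing in for a real sampling step are assumptions, while `LlmTaskInput` and the `Scheduler` methods come from the code above.

```python
from sharktank.utils.llm_task import LlmTaskInput
# Assumed module path; the diff does not show where Scheduler lives.
from sharktank.utils.llm_scheduler import Scheduler

scheduler = Scheduler(batch_size=4)

# One request, allowed to be scheduled for up to 4 decode steps.
task = LlmTaskInput(task_id="req-0", tokens=[1, 2, 3], seq_len=3, pages=[0], start_position=2)
scheduler.schedule_task(task, count=4)

while scheduler.has_pending_tasks():
    for t in scheduler.get_scheduled_tasks():  # dequeues up to batch_size tasks
        next_token = 42  # stand-in for running the model and sampling a token
        if scheduler.on_task_complete(t.task_id, next_token):
            print(t.task_id, "finished with tokens", t.tokens)
```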
@@ -0,0 +1,167 @@
import dataclasses
import iree.runtime
import numpy

from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple

import torch


def fill_page_table(bs: int, count: int, page_ids: list[list[int]]) -> numpy.ndarray:
    pages = numpy.zeros((bs, count), dtype=numpy.int64)

    for i, ids in enumerate(page_ids):
        pages[i, : len(ids)] = ids[:count]

    return pages


@dataclasses.dataclass
class LlmTaskInput:
    task_id: str
    tokens: List[int]
    seq_len: int
    pages: List[int]

    start_position: Optional[int] = None
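As a quick illustration (not part of the diff), `fill_page_table` right-pads each request's page-id list with zeros into a fixed `(bs, count)` table; using the function defined above:

```python
# Two requests holding different numbers of KV-cache pages, padded out to 3 columns.
pages = fill_page_table(bs=2, count=3, page_ids=[[4, 5], [7]])
# pages -> [[4, 5, 0],
#           [7, 0, 0]]   (dtype=int64)
```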
Review comment on the `LlmTask` class below: I would move this to the scheduler and generalize it. Most of the basics are pretty use-case agnostic, and it would be easy to write some tests on the scheduler + task.

class LlmTask(ABC):
    def __init__(
        self,
        invocation_fn: Callable[
            [List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]],
            Tuple[numpy.ndarray, Optional[numpy.ndarray]],
        ],
        llm_task_inputs: List[LlmTaskInput],
        batch_size: int,
        block_stride: int,
    ):
        self._invocation_fn = invocation_fn
        self._task_inputs: List[LlmTaskInput] = llm_task_inputs
        self._batch_size = batch_size
        self._block_stride = block_stride

    @abstractmethod
    def _prepare_args(
        self, task_inputs: List[LlmTaskInput], cache
    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
        pass

    @abstractmethod
    def _process_results(
        self, results
    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
        pass

    def run(
        self, cache_state: iree.runtime.DeviceArray | torch.Tensor
    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
        task_inputs = self._task_inputs

        args = self._prepare_args(task_inputs, cache_state)
        results = self._invocation_fn(*args)
        logits, indices = self._process_results(results)
        return logits, indices


class PrefillTask(LlmTask):
    def _prepare_args(
        self,
        task_inputs: List[LlmTaskInput],
        cache: iree.runtime.DeviceArray | torch.Tensor,
    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
        block_stride = self._block_stride
        bs = self._batch_size

        tokens = [task_input.tokens for task_input in task_inputs]
        page_ids = [task_input.pages for task_input in task_inputs]

        max_len = max(len(input_tokens) for input_tokens in tokens)
        blocks = int(numpy.ceil(max_len / block_stride))
        blocked_len = blocks * block_stride

        tokens_ = numpy.zeros((bs, blocked_len), dtype=numpy.int64)
        lens_ = numpy.ones((bs,), dtype=numpy.int64)

        for i, input_tokens in enumerate(tokens):
            tokens_[i, : len(input_tokens)] = input_tokens
            lens_[i] = len(input_tokens)

        pages_ = fill_page_table(bs, blocks, page_ids)
        args = [
            tokens_,
            lens_,
            pages_,
            cache,
        ]
        return args

    def _process_results(
        self, results
    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
        if isinstance(results, tuple):
            logits, indices = results
            logits = numpy.asarray(logits)
            indices = numpy.asarray(indices)
        else:
            logits = numpy.asarray(results)
            indices = None
        return logits, indices


class DecodeTask(LlmTask):
    def _prepare_args(
        self,
        task_inputs: List[LlmTaskInput],
        cache: iree.runtime.DeviceArray | torch.Tensor,
    ) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
        assert all(
            task_input.start_position is not None for task_input in task_inputs
        ), "`start_positions` is a required argument for `decode`"

        block_stride = self._block_stride
        decode_bs = self._batch_size
        bs = len(task_inputs)

        tokens = [task_input.tokens[-1] for task_input in task_inputs]
        page_ids = [task_input.pages for task_input in task_inputs]
        start_positions = [task_input.start_position for task_input in task_inputs]

        max_len = max(start_positions) + 1
        blocks = int(numpy.ceil(max_len / block_stride))

        tokens_ = numpy.zeros((decode_bs, 1), dtype=numpy.int64)
        lens_ = numpy.ones((decode_bs,), dtype=numpy.int64)
        pos_ = numpy.ones((decode_bs,), dtype=numpy.int64)

        for i in range(bs):
            tokens_[i, 0] = tokens[i]
            lens_[i] = start_positions[i] + 1
            pos_[i] = start_positions[i]

        pages_ = fill_page_table(decode_bs, blocks, page_ids)

        args = [
            tokens_,
            lens_,
            pos_,
            pages_,
            cache,
        ]
        return args

    def _process_results(
        self, results
    ) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
        if isinstance(results, tuple):
            logits, indices = results
        else:
            k = 8
            logits = torch.asarray(numpy.asarray(results))
            logits, indices = torch.topk(logits, k)

        logits = numpy.asarray(logits)
        indices = numpy.asarray(indices)
        return logits, indices
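To make the task contract concrete (illustrative only, not part of the diff), `PrefillTask.run` works with any callable that accepts `(tokens, lens, pages, cache)` and returns logits; the `fake_invocation_fn`, vocabulary size, and cache shape below are placeholders:

```python
import numpy


def fake_invocation_fn(tokens, lens, pages, cache):
    # A real invocation would run the compiled model; here we return zero logits
    # shaped (batch, padded_seq_len, vocab) with an arbitrary vocab size of 32.
    return numpy.zeros((tokens.shape[0], tokens.shape[1], 32), dtype=numpy.float32)


inputs = [
    LlmTaskInput(task_id="p0", tokens=[1, 2, 3, 4, 5], seq_len=5, pages=[0]),
    LlmTaskInput(task_id="p1", tokens=[1, 2], seq_len=2, pages=[1]),
]
prefill = PrefillTask(fake_invocation_fn, inputs, batch_size=2, block_stride=4)
logits, indices = prefill.run(cache_state=numpy.zeros((8, 256), dtype=numpy.float16))
assert logits.shape[0] == 2 and indices is None  # non-tuple results carry no top-k indices
```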
Review comment: Make enough parts in the scheduler and write some tests for it. I think you will understand the issues with your API if you attempt to use it standalone vs retrofitting it into the old process.
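A minimal sketch of the kind of standalone test the reviewers appear to be asking for, exercising only `Scheduler` and `LlmTaskInput` (the `Scheduler` import path is assumed, since the diff does not name the new module):

```python
from sharktank.utils.llm_task import LlmTaskInput
from sharktank.utils.llm_scheduler import Scheduler  # assumed path


def test_scheduler_requeues_until_count_exhausted():
    scheduler = Scheduler(batch_size=2)
    task = LlmTaskInput(task_id="t0", tokens=[1], seq_len=1, pages=[0], start_position=0)
    scheduler.schedule_task(task, count=2)

    # First step: the task is dequeued; remaining_count is still nonzero,
    # so on_task_complete re-queues it and reports "not done".
    (scheduled,) = scheduler.get_scheduled_tasks()
    assert not scheduler.on_task_complete(scheduled.task_id, last_token=5)
    assert scheduler.has_pending_tasks()

    # Second step: remaining_count reaches zero, so the task completes.
    (scheduled,) = scheduler.get_scheduled_tasks()
    assert scheduler.on_task_complete(scheduled.task_id, last_token=6)
    assert not scheduler.has_pending_tasks()
    assert scheduled.tokens == [1, 5, 6]
```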