53 changes: 53 additions & 0 deletions sharktank/sharktank/utils/llm_scheduler.py
Collaborator

Break the scheduler into enough standalone parts and write some tests for it. I think you will understand the issues with your API if you try to use it standalone rather than retrofitting it into the old process.
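For example, a standalone test might look like this (a rough sketch, assuming pytest discovery and the LlmTaskInput dataclass from llm_task.py):

from sharktank.utils.llm_scheduler import Scheduler
from sharktank.utils.llm_task import LlmTaskInput


def test_scheduler_batches_and_requeues():
    scheduler = Scheduler(batch_size=2)
    tasks = [
        LlmTaskInput(task_id=str(i), tokens=[1, 2], seq_len=2, pages=[0], start_position=1)
        for i in range(3)
    ]
    for task in tasks:
        scheduler.schedule_task(task, count=2)

    # Only batch_size tasks come back per call; the third stays queued.
    batch = scheduler.get_scheduled_tasks()
    assert len(batch) == 2
    assert scheduler.has_pending_tasks()

    # A task with dispatches remaining is re-queued, not finished.
    assert not scheduler.on_task_complete(batch[0].task_id, last_token=7)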

@@ -0,0 +1,53 @@
from dataclasses import dataclass
from sharktank.utils.llm_task import LlmTaskInput
from typing import List, Dict


@dataclass
class SchedulerEntry:
task: LlmTaskInput
remaining_count: int


class Scheduler:
def __init__(self, batch_size: int):
self._batch_size = batch_size
self._queue: List[LlmTaskInput] = []
self._schedule_count: Dict[str, SchedulerEntry] = {}

def schedule_task(self, task: LlmTaskInput, count: int):
self._queue.append(task)
self._schedule_count[task.task_id] = SchedulerEntry(
task=task, remaining_count=count
)

    def get_scheduled_tasks(self) -> List[LlmTaskInput]:
        # Pop up to batch_size tasks; each dispatch counts against the
        # task's remaining_count.
batch_size = self._batch_size
schedule_tasks = self._queue[:batch_size]
self._queue = self._queue[batch_size:]

for task in schedule_tasks:
self._schedule_count[task.task_id].remaining_count -= 1

return schedule_tasks

def has_pending_tasks(self) -> bool:
return len(self._queue) > 0

def on_task_complete(self, task_id: str, last_token: int) -> bool:
Collaborator

I am pretty sure this is a bad section. Calling a callback on completion is likely going to produce bad results.
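A pull-style shape would keep token mutation in the caller instead (purely a hypothetical sketch; requeue does not exist in this PR):

# Hypothetical pull-based driver: the caller owns sampling and token
# mutation, and only tells the scheduler whether to run the task again.
batch = scheduler.get_scheduled_tasks()
for task in batch:
    token = sample_fn(task)  # caller-owned sampling, not shown here
    task.tokens.append(token)
    scheduler.requeue(task.task_id)  # hypothetical replacement for on_task_complete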

        # Record the new token and advance the position for the next decode.
        task = self._schedule_count[task_id].task
        task.tokens.append(last_token)
        task.seq_len = len(task.tokens)
        task.start_position = task.seq_len - 1

if self._schedule_count[task_id].remaining_count == 0:
del self._schedule_count[task_id]
return True

self._queue.append(self._schedule_count[task_id].task)
return False

def remove_task(self, task_id: str):
self._queue = [task for task in self._queue if task.task_id != task_id]
if task_id in self._schedule_count:
del self._schedule_count[task_id]
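Taken together, a driver loop over this API looks roughly like the following (a sketch; task_inputs, max_decode_steps, run_model, and sample_tokens are hypothetical stand-ins for inputs, the model invocation, and sampling):

scheduler = Scheduler(batch_size=4)
for task in task_inputs:  # LlmTaskInput values prepared elsewhere
    scheduler.schedule_task(task, count=max_decode_steps)

while scheduler.has_pending_tasks():
    batch = scheduler.get_scheduled_tasks()
    logits = run_model(batch)  # hypothetical model invocation
    for task, token in zip(batch, sample_tokens(logits)):
        scheduler.on_task_complete(task.task_id, last_token=token)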
167 changes: 167 additions & 0 deletions sharktank/sharktank/utils/llm_task.py
@@ -0,0 +1,167 @@
import dataclasses
import iree.runtime
import numpy

from abc import ABC, abstractmethod
from typing import Callable, List, Optional, Tuple

import torch


def fill_page_table(bs: int, count: int, page_ids: List[List[int]]) -> numpy.ndarray:
    # Build a (bs, count) page table, zero-padding rows with fewer pages.
    pages = numpy.zeros((bs, count), dtype=numpy.int64)

for i, ids in enumerate(page_ids):
pages[i, : len(ids)] = ids[:count]

return pages


@dataclasses.dataclass
class LlmTaskInput:
task_id: str
tokens: List[int]
seq_len: int
pages: List[int]

start_position: Optional[int] = None


class LlmTask(ABC):
Collaborator

I would move this to the scheduler and generalize it. Most of the basics are pretty use-case agnostic, and it would be easy to write some tests on the scheduler + task.
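For example, a combined test could drive a PrefillTask built from scheduler output against a stubbed invocation function (a sketch; fake_invoke and its shapes are invented for illustration, and scheduler is assumed populated as above):

import numpy

def fake_invoke(tokens, lens, pages, cache):
    # Stub model: uniform logits of shape (bs, padded_seq_len, vocab=32).
    return numpy.zeros((tokens.shape[0], tokens.shape[1], 32), dtype=numpy.float32)

task = PrefillTask(
    invocation_fn=fake_invoke,
    llm_task_inputs=scheduler.get_scheduled_tasks(),
    batch_size=4,
    block_stride=16,
)
logits, indices = task.run(cache_state=numpy.zeros((4, 256), dtype=numpy.float16))
assert indices is None  # non-tuple results carry no top-k indices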

def __init__(
self,
invocation_fn: Callable[
[List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]],
Tuple[numpy.ndarray, Optional[numpy.ndarray]],
],
llm_task_inputs: List[LlmTaskInput],
batch_size: int,
block_stride: int,
):
self._invocation_fn = invocation_fn
self._task_inputs: List[LlmTaskInput] = llm_task_inputs
self._batch_size = batch_size
self._block_stride = block_stride

@abstractmethod
def _prepare_args(
self, task_inputs: List[LlmTaskInput], cache
) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
pass

@abstractmethod
def _process_results(
self, results
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
pass

def run(
self, cache_state: iree.runtime.DeviceArray | torch.Tensor
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
task_inputs = self._task_inputs

args = self._prepare_args(task_inputs, cache_state)
results = self._invocation_fn(*args)
logits, indices = self._process_results(results)
return logits, indices


class PrefillTask(LlmTask):
def _prepare_args(
self,
task_inputs: List[LlmTaskInput],
cache: iree.runtime.DeviceArray | torch.Tensor,
) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
block_stride = self._block_stride
bs = self._batch_size

tokens = [task_input.tokens for task_input in task_inputs]
page_ids = [task_input.pages for task_input in task_inputs]

max_len = max(len(input_tokens) for input_tokens in tokens)
blocks = int(numpy.ceil(max_len / block_stride))
blocked_len = blocks * block_stride

tokens_ = numpy.zeros((bs, blocked_len), dtype=numpy.int64)
lens_ = numpy.ones((bs,), dtype=numpy.int64)

for i, input_tokens in enumerate(tokens):
tokens_[i, : len(input_tokens)] = input_tokens
lens_[i] = len(input_tokens)

pages_ = fill_page_table(bs, blocks, page_ids)
args = [
tokens_,
lens_,
pages_,
cache,
]
return args

def _process_results(
self, results
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
if isinstance(results, tuple):
logits, indices = results
logits = numpy.asarray(logits)
indices = numpy.asarray(indices)
else:
logits = numpy.asarray(results)
indices = None
return logits, indices


class DecodeTask(LlmTask):
def _prepare_args(
self,
task_inputs: List[LlmTaskInput],
cache: iree.runtime.DeviceArray | torch.Tensor,
) -> List[numpy.ndarray | iree.runtime.DeviceArray | torch.Tensor]:
        assert all(
            task_input.start_position is not None for task_input in task_inputs
        ), "`start_position` must be set on every task input for `decode`"

block_stride = self._block_stride
decode_bs = self._batch_size
bs = len(task_inputs)

tokens = [task_input.tokens[-1] for task_input in task_inputs]
page_ids = [task_input.pages for task_input in task_inputs]
start_positions = [task_input.start_position for task_input in task_inputs]

max_len = max(start_positions) + 1
blocks = int(numpy.ceil(max_len / block_stride))

tokens_ = numpy.zeros((decode_bs, 1), dtype=numpy.int64)
lens_ = numpy.ones((decode_bs,), dtype=numpy.int64)
pos_ = numpy.ones((decode_bs,), dtype=numpy.int64)

for i in range(bs):
tokens_[i, 0] = tokens[i]
lens_[i] = start_positions[i] + 1
pos_[i] = start_positions[i]

pages_ = fill_page_table(decode_bs, blocks, page_ids)

args = [
tokens_,
lens_,
pos_,
pages_,
cache,
]
return args

def _process_results(
self, results
) -> Tuple[numpy.ndarray, Optional[numpy.ndarray]]:
if isinstance(results, tuple):
logits, indices = results
        else:
            # Fall back to computing top-k (k=8) on the host when the model
            # returns raw logits instead of a (logits, indices) pair.
            k = 8
            logits = torch.asarray(numpy.asarray(results))
            logits, indices = torch.topk(logits, k)

logits = numpy.asarray(logits)
indices = numpy.asarray(indices)
return logits, indices
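End to end, decode is driven the same way; a minimal sketch with a stub that already returns a (logits, indices) pair, assuming a batch of decode-ready LlmTaskInput values (start_position set) and a cache prepared elsewhere:

import numpy

def fake_decode(tokens, lens, pos, pages, cache):
    bs = tokens.shape[0]
    # Stub model that pretends top-k (k=8) already happened on device.
    return (
        numpy.zeros((bs, 1, 8), dtype=numpy.float32),
        numpy.zeros((bs, 1, 8), dtype=numpy.int64),
    )

decode = DecodeTask(
    invocation_fn=fake_decode,
    llm_task_inputs=batch,  # decode-ready LlmTaskInput values
    batch_size=4,
    block_stride=16,
)
logits, indices = decode.run(cache_state=cache)  # cache prepared elsewhere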