[Sharktank] Llm Task Scheduler #2418
base: main
Conversation
- Add `task_id` to `LlmTaskInput`
Codecov Report

@@           Coverage Diff           @@
##             main    #2418   +/-   ##
=======================================
  Coverage        ?   78.21%
=======================================
  Files           ?      243
  Lines           ?    22786
  Branches        ?        0
=======================================
  Hits            ?    17821
  Misses          ?     4965
  Partials        ?        0
…aione/SHARK-Platform into sharktank-llm-task-scheduler
    start_position: Optional[int] = None


class LlmTask(ABC):
I would move this to the scheduler and generalize it. Most of the basics are fairly use-case agnostic, and it would be easy to write some tests on the scheduler + task.
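For concreteness, a minimal sketch of what a generalized task could look like if it lived next to the scheduler (`TaskInput`, `Task`, and the `run` signature here are illustrative names, not the PR's actual API):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class TaskInput:
    # Hypothetical generalized input: the scheduler only needs an id and an
    # opaque payload; LLM-specific fields (tokens, pages, start_position)
    # would live in a subclass or inside the payload.
    task_id: str
    payload: Any
    start_position: Optional[int] = None


class Task(ABC):
    """Use-case-agnostic unit of work that a scheduler can batch and run."""

    def __init__(self, inputs: List[TaskInput]):
        self.inputs = inputs

    @abstractmethod
    def run(self, *resources: Any) -> Any:
        """Execute the batched work against shared resources (e.g. a cache)."""
```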
    pages=page_ids[i],
)
# Submit prefill task
self._prefill_scheduler.schedule_task(task_input, 1)
This feels wrong - why are you submitting to both prefill and decode?
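For illustration, one shape the handoff could take if a request only ever sits in one scheduler at a time (a sketch; `schedule_task`, `run`, and `with_token` are assumed names, not the PR's API):

```python
class Decoder:
    """Sketch only: prefill first, then hand finished tasks to decode."""

    def generate(self, task_inputs, steps):
        # Each request is prefilled exactly once, so it is tracked for
        # a single schedule.
        for task_input in task_inputs:
            self._prefill_scheduler.schedule_task(task_input, 1)

        # Run prefill to completion, then move each request into decode,
        # tracked for at most `steps` schedules.
        for task_input, first_token in self._prefill_scheduler.run(*self._cache):
            self._decode_scheduler.schedule_task(
                task_input.with_token(first_token), steps
            )
```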
    batch_size=self._decode_bs,
    block_stride=self._block_stride,
)
logits, indices = decode_task.run(*self._cache)
This looks wrong. You make a task, then call run on it, then call the scheduler? You should be making a job and submitting it. The scheduler should basically have a schedule command for each task (a task being a separate request), then a call for completion. Something is off here.
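A rough sketch of the submit-plus-completion shape this comment seems to be asking for (all names are guesses, not the PR's API):

```python
class Scheduler:
    """Minimal sketch: one schedule call per task, one call for completion."""

    def __init__(self):
        self._pending = []

    def schedule(self, job):
        # A job wraps a single request; nothing runs at submit time.
        self._pending.append(job)

    def complete(self, *resources):
        # Run everything that was scheduled and hand results back keyed
        # by task id; the scheduler owns batching and the run loop.
        results = {job.task_id: job.run(*resources) for job in self._pending}
        self._pending.clear()
        return results
```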
last = selection_fn(logits, indices, [0] * len(task_inputs))
for task_input, token in zip(task_inputs, last):
    selections_map[task_input.task_id].append(token)
    self._decode_scheduler.on_task_complete(
This on-task-completion callback is wrong. You should submit, then process the returned results.
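As a usage sketch of submit-then-process, building on the hypothetical `Scheduler` above (`selection_fn`, `selections_map`, and `cache` stand in for the surrounding code):

```python
# Submit every task first; nothing runs yet.
for task_input in task_inputs:
    scheduler.schedule(task_input)

# Then process whatever complete() hands back, with no callback involved.
for task_id, (logits, indices) in scheduler.complete(*cache).items():
    token = selection_fn(logits, indices)
    selections_map[task_id].append(token)
```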
        token,
    )
    if token == eos_token:
        self._decode_scheduler.remove_task(task_input.task_id)
Schedulers should take requests, process them, then be done. You shouldn't be managing them externally.
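One hedged sketch of a scheduler that owns the whole lifecycle, so finished tasks drop out on their own and callers never invoke `remove_task` (`step_fn` and the eos handling are assumptions):

```python
class DecodeScheduler:
    """Sketch: requests enter, get processed to completion, and leave."""

    def __init__(self, eos_token, max_steps):
        self._eos_token = eos_token
        self._max_steps = max_steps

    def run_to_completion(self, tasks, step_fn):
        # step_fn(task) -> next token; assumed to wrap the batched run.
        selections = {task.task_id: [] for task in tasks}
        active = list(tasks)
        for _ in range(self._max_steps):
            if not active:
                break
            next_active = []
            for task in active:
                token = step_fn(task)
                selections[task.task_id].append(token)
                # Finished tasks simply fall out of the loop; no external
                # remove_task bookkeeping is required.
                if token != self._eos_token:
                    next_active.append(task)
            active = next_active
        return selections
```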
def has_pending_tasks(self) -> bool:
    return len(self._queue) > 0


def on_task_complete(self, task_id: str, last_token: int) -> bool:
I am pretty sure this is a bad section. Calling a callback on completion is likely going to produce bad results.
Make enough parts in the scheduler and write some tests for it. I think you will understand the issues with your API if you attempt to use it standalone vs. retrofitting it into the old process.
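A standalone pytest-style test along those lines, written against the hypothetical `Scheduler` sketched earlier rather than the PR's class:

```python
class FakeJob:
    def __init__(self, task_id, value):
        self.task_id = task_id
        self._value = value

    def run(self):
        return self._value


def test_scheduler_runs_each_scheduled_job_once():
    scheduler = Scheduler()
    scheduler.schedule(FakeJob("a", 1))
    scheduler.schedule(FakeJob("b", 2))

    assert scheduler.complete() == {"a": 1, "b": 2}
    # Nothing is pending afterwards; a second completion returns nothing.
    assert scheduler.complete() == {}
```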
Description
- Adds a `Scheduler` into the `vmfb-runner`
- Adds a `submit + run` flow to the `Batcher`
- Each request becomes an `LlmTaskInput`, which gets added to the `Scheduler` with a `count` variable. The `count` var allows the scheduler to determine how long each request should be tracked. For example, the `decode_scheduler` will track each task for at most `steps` number of schedules (see the sketch after this list).
- The `Decoder` submits the tasks to the `Batcher`, then returns the results from calling `run`.
- This should make it easier to extend to other uses of the `Batcher` (i.e. perplexity). Would need to think of a way to handle the fact that PPL needs all of the raw logits, not just the selections, but that should be possible to do. Maybe by moving the `selection` logic back outside of the `Batcher`, and just providing it a callback for where it should submit the logits + indices.
- Split the code into `llm_scheduler.py`, `llm_task.py` and `llm_utils.py`, because the file was getting complex.
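A minimal sketch of how the `count` bookkeeping described above might behave (the real `schedule_task` signature and scheduler internals may differ):

```python
class CountingScheduler:
    """Sketch: each task is eligible for at most `count` schedules."""

    def __init__(self):
        self._remaining = {}  # task_id -> schedules left

    def schedule_task(self, task_input, count):
        # count=1 would suit prefill; count=steps would suit decode.
        self._remaining[task_input.task_id] = count

    def next_batch(self, task_inputs):
        # Decrement each tracked task; drop it once its count reaches zero.
        batch = []
        for task_input in task_inputs:
            left = self._remaining.get(task_input.task_id, 0)
            if left > 0:
                self._remaining[task_input.task_id] = left - 1
                batch.append(task_input)
        return batch
```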
Next Steps
- Implement a `ChunkPrefill` scheduler
- Refactor `llm_utils`, but maintain the same cli commands