[NPUW] Support generate more than 1 token per inference #31578
Conversation
Force-pushed from 45a6890 to 621189b
@cwzrad if you are bringing this in for speculative decoding, please synchronize with @AsyaPronina to avoid duplication.
@dmatveev yes, we have a sync. Hopefully this NPUW change, together with her GenAI pipeline changes in openvinotoolkit/openvino.genai#2544, can make fast draft work, with some co-debugging.
Signed-off-by: wenzengc <[email protected]>
Force-pushed from 68964c1 to f74a821
Supported dynamic number of output tokens for the generate model only
AR for @AsyaPronina:
There are a lot of changes in this file; I expect a thorough review from @AsyaPronina here.
Reviewed -> finalized, thanks!
// number of candidates. To differentiate prefill and generate
// calls for main model, we just check that start position id
// is 0, meaning this is the first input prompt.
if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] > 1 && position_ids->data<int64_t>()[0] == 0) {
DM: How will it work with prefix caching?
Looks like that should work
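
For illustration, here is a minimal standalone sketch (not the actual NPUW code; the helper name and the plain-vector inputs are assumptions for brevity) of the dispatch rule discussed above: a request is treated as prefill only when it carries more than one token and its first position id is 0, while a multi-token request whose positions start past 0 is a generate call carrying several speculative candidates.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper; the real code reads these values from ov::Tensor inputs.
bool is_prefill(const std::vector<int64_t>& input_ids,
                const std::vector<int64_t>& position_ids) {
    return input_ids.size() > 1 && position_ids.front() == 0;
}

int main() {
    // First prompt: 4 tokens starting at position 0 -> prefill.
    std::cout << is_prefill({11, 12, 13, 14}, {0, 1, 2, 3}) << "\n";  // prints 1
    // 3 draft candidates starting at position 20 -> generate call.
    std::cout << is_prefill({21, 22, 23}, {20, 21, 22}) << "\n";      // prints 0
    return 0;
}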
[DO NOT MERGE][Validation is in progress] Polishing the PR for merge
{char{0x4c}, char{0x4c}, char{0x4d}, char{0x43}, char{0x4d}, char{0x4f}};

- const constexpr char* NPUW_SERIALIZATION_VERSION = "0.8";
+ const constexpr char* NPUW_SERIALIZATION_VERSION = "0.10";
Why did we skip 0.9? Or does 0.9 come with the prefix caching?
Yes, it should!
So it doesn't make sense then. Would you set it back to 0.9 once the prefix caching is merged?
@AsyaPronina I'd recommend making it 0.9. First come, first served.
Fixed, thanks!
Co-authored-by: Dmitry Matveev <[email protected]>
20efbd5
Details:
For the KV cache model, support generating more than 1 token per inference, which is needed by speculative decoding.
Also update the KV cache according to the position ids, which is needed for fast draft.
Basically, if we have already saved 20 KV cache entries, then the next position id should be 20. Assume in this case we have 3 input tokens: the position ids are [20, 21, 22], and after inference we have saved 3 more KV cache entries, so the total becomes 23. But if, after verification on the application side, we find that the token at position 22 is not correct, then for the next inference the position ids are [22, 23, 24]; the starting position id has only increased by 2, so we know the last KV cache entry from the previous inference is a dirty one.
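
As a worked illustration of that rule, here is a small self-contained sketch (the function name and signature are hypothetical, not taken from this PR): given how many KV entries are already cached and the position ids of the next request, it computes how many trailing entries belong to rejected draft tokens and should be discarded.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// If the next request starts before the current cache end, the tail past that
// start position was produced for draft tokens that failed verification.
std::size_t dirty_kv_entries(std::size_t cached_kv_len,
                             const std::vector<int64_t>& next_position_ids) {
    assert(!next_position_ids.empty());
    const auto start = static_cast<std::size_t>(next_position_ids.front());
    return start < cached_kv_len ? cached_kv_len - start : 0;
}

int main() {
    // 20 entries cached + 3 candidates inferred = 23 cached; the token at
    // position 22 is rejected, so the next request starts at 22 -> drop 1 entry.
    std::cout << dirty_kv_entries(23, {22, 23, 24}) << "\n";  // prints 1
    // All candidates accepted: the next request starts at the cache end -> 0 dirty.
    std::cout << dirty_kv_entries(23, {23}) << "\n";          // prints 0
    return 0;
}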
...
Tickets: