Skip to content

Commit ae62b28

Browse files
committed
fix mtp chunkprefill output, fix unit test
1 parent 1de9f87 commit ae62b28

File tree

7 files changed

+74
-37
lines changed

7 files changed

+74
-37
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,11 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
710710
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
711711
const paddle::Tensor& accept_num,
712712
const paddle::Tensor& not_need_stop,
713+
const paddle::Tensor& seq_lens_decoder,
714+
const paddle::Tensor& prompt_lens,
713715
int64_t rank_id,
714-
bool save_each_rank);
716+
bool save_each_rank,
717+
bool skip_prefill);
715718

716719

717720
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,

custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@
2828
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
2929
const paddle::Tensor& accept_num,
3030
const paddle::Tensor& not_need_stop,
31+
const paddle::Tensor& seq_lens_decoder,
32+
const paddle::Tensor& prompt_lens,
3133
int64_t rank_id,
3234
int msg_queue_id,
33-
int save_each_rank) {
35+
int save_each_rank,
36+
bool skip_prefill) {
3437
// printf("enter save output");
3538
if (!save_each_rank && rank_id > 0) {
3639
return;
@@ -43,6 +46,11 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
4346
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
4447
int* accept_num_data = accept_num_cpu.data<int>();
4548

49+
auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
50+
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
51+
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
52+
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();
53+
4654
if (const char* inference_msg_queue_id_env_p =
4755
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
4856
std::string inference_msg_queue_id_env_str(
@@ -95,7 +103,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
95103
msg_sed.mtext[1] = bsz;
96104

97105
for (int i = 2; i < MAX_BSZ + 2; i++) {
98-
if (i - 2 >= bsz) {
106+
if (i - 2 >= bsz || (skip_prefill && seq_lens_decoder_data[i - 2] < prompt_lens_data[i - 2])) {
99107
msg_sed.mtext[i] = 0;
100108
} else {
101109
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
@@ -125,32 +133,38 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
125133
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
126134
const paddle::Tensor& accept_num,
127135
const paddle::Tensor& not_need_stop,
136+
const paddle::Tensor& seq_lens_decoder,
137+
const paddle::Tensor& prompt_lens,
128138
int64_t rank_id,
129-
bool save_each_rank) {
139+
bool save_each_rank,
140+
bool skip_prefill) {
130141
SpeculateSaveWithOutputMsg(
131-
accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
142+
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, 1, save_each_rank, skip_prefill);
132143
}
133144

134145
void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
135146
const paddle::Tensor& accept_num,
136147
const paddle::Tensor& not_need_stop,
148+
const paddle::Tensor& seq_lens_decoder,
149+
const paddle::Tensor& prompt_lens,
137150
int64_t rank_id,
138151
int msg_queue_id,
139-
bool save_each_rank) {
152+
bool save_each_rank,
153+
bool skip_prefill) {
140154
SpeculateSaveWithOutputMsg(
141-
accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
155+
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, msg_queue_id, save_each_rank, skip_prefill);
142156
}
143157

144158
PD_BUILD_STATIC_OP(speculate_save_output)
145-
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
146-
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
159+
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
160+
.Attrs({"rank_id: int64_t", "save_each_rank: bool", "skip_prefill: bool"})
147161
.Outputs({"x_out"})
148162
.SetInplaceMap({{"accept_tokens", "x_out"}})
149163
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
150164

151165
PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
152-
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
153-
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
166+
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
167+
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool", "skip_prefill: bool"})
154168
.Outputs({"x_out"})
155169
.SetInplaceMap({{"accept_tokens", "x_out"}})
156170
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));

custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,33 @@
2020

2121
__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
2222
const int64_t *accept_tokens,
23-
const int *accept_num,
23+
int *accept_num,
2424
const bool *stop_flags,
2525
const int *seq_lens_encoder,
26-
const int *seq_lens_decoder,
26+
int *seq_lens_decoder,
2727
const int64_t *step_idx,
2828
int bs,
2929
int length,
3030
int max_draft_tokens) {
3131
int tid = threadIdx.x;
32-
if (tid < bs && !stop_flags[tid]) {
33-
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
34-
const int64_t *accept_tokens_now =
35-
accept_tokens + tid * max_draft_tokens;
36-
const int seq_len_dec = seq_lens_decoder[tid];
37-
const int seq_len_enc = seq_lens_encoder[tid];
38-
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
39-
// printf("step_idx[tid] %d\n", step_idx[tid]);
40-
if (step_idx[tid] >= 0) {
41-
for (int i = 0; i < accept_num[tid]; i++) {
42-
pre_ids_all_now[step_idx[tid] - i] =
43-
accept_tokens_now[accept_num[tid] - 1 - i];
44-
// printf("pre_ids_all_now[step_idx[tid] - i] %d \n",
45-
// pre_ids_all_now[step_idx[tid] - i]);
32+
33+
if (tid < bs) {
34+
if (!stop_flags[tid]) {
35+
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
36+
const int64_t *accept_tokens_now =
37+
accept_tokens + tid * max_draft_tokens;
38+
const int seq_len_dec = seq_lens_decoder[tid];
39+
const int seq_len_enc = seq_lens_encoder[tid];
40+
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
41+
if (step_idx[tid] >= 0) {
42+
for (int i = 0; i < accept_num[tid]; i++) {
43+
pre_ids_all_now[step_idx[tid] - i] =
44+
accept_tokens_now[accept_num[tid] - 1 - i];
45+
}
4646
}
47+
} else {
48+
accept_num[tid] = 0;
49+
seq_lens_decoder[tid] = 0;
4750
}
4851
}
4952
}
@@ -67,10 +70,10 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
6770
speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
6871
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
6972
accept_tokens.data<int64_t>(),
70-
accept_num.data<int>(),
73+
const_cast<int*>(accept_num.data<int>()),
7174
stop_flags.data<bool>(),
7275
seq_lens_encoder.data<int>(),
73-
seq_lens_decoder.data<int>(),
76+
const_cast<int*>(seq_lens_decoder.data<int>()),
7477
step_idx.data<int64_t>(),
7578
bs,
7679
length,
@@ -86,6 +89,9 @@ PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
8689
"seq_lens_encoder",
8790
"seq_lens_decoder",
8891
"step_idx"})
89-
.Outputs({"pre_ids_all_out"})
90-
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
92+
.Outputs({"pre_ids_all_out", "accept_num_out", "seq_lens_decoder_out"})
93+
.SetInplaceMap({
94+
{"pre_ids_all", "pre_ids_all_out"},
95+
{"accept_num", "accept_num_out"},
96+
{"seq_lens_decoder", "seq_lens_decoder_out"}})
9197
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));

custom_ops/gpu_ops/speculate_decoding/speculate_update.cu

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@ __global__ void speculate_update(int *seq_lens_encoder,
7171
}
7272
draft_tokens[bid * max_draft_tokens] =
7373
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
74-
if (stop_flag_now_int) {
75-
seq_lens_decoder[bid] = 0;
76-
}
7774
} else if (bid >= real_bsz && bid < max_bsz) {
7875
stop_flag_now_int = 1;
7976
}

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
save_output,
6565
save_output_topk,
6666
set_stop_value_multi_ends,
67-
speculate_clear_accept_nums,
6867
speculate_get_output_padding_offset,
6968
speculate_get_padding_offset,
7069
speculate_get_seq_lens_output,
@@ -369,12 +368,13 @@ def post_process_specualate(
369368
model_output.accept_tokens,
370369
model_output.accept_num,
371370
model_output.not_need_stop,
371+
model_output.seq_lens_decoder,
372+
model_output.prompt_lens,
372373
model_output.mp_rank,
373374
save_each_rank,
375+
envs.ENABLE_V1_KVCACHE_SCHEDULER,
374376
)
375377

376-
speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)
377-
378378
# Update pre_ids through accept tokens
379379

380380
speculate_set_value_by_flags_and_idx(

fastdeploy/worker/output.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ class ModelOutputData:
250250
"""
251251
stop_seqs_len: paddle.Tensor = None
252252

253+
"""
254+
the length of input prompt
255+
"""
256+
prompt_lens: paddle.Tensor = None
257+
253258

254259
@dataclass
255260
class ModelRunnerOutput:

tests/operators/test_speculative_schedule_cache.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ def cpu_reference(
1010
draft_tokens,
1111
block_tables,
1212
stop_flags,
13+
prompt_lens,
1314
seq_lens_this_time,
15+
seq_lens_encoder,
1416
seq_lens_decoder,
1517
step_seq_lens_decoder,
1618
step_draft_tokens,
@@ -101,7 +103,9 @@ def setUp(self):
101103
self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
102104
# stop_flags length is max_bsz, others are real_bsz
103105
self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
106+
self.prompt_lens = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int64))
104107
self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
108+
self.seq_lens_encoder = paddle.to_tensor(np.array([1, 1, 1], dtype=np.int32))
105109
self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
106110

107111
# Will be filled by kernel for the triggering bids only
@@ -129,7 +133,9 @@ def setUp(self):
129133
self.np_draft_tokens = self.draft_tokens.numpy().copy()
130134
self.np_block_tables = self.block_tables.numpy().copy()
131135
self.np_stop_flags = self.stop_flags.numpy().copy()
136+
self.np_prompt_lens = self.prompt_lens.numpy().copy()
132137
self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
138+
self.np_seq_lens_encoder = self.seq_lens_encoder.numpy().copy()
133139
self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
134140
self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
135141
self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
@@ -146,7 +152,9 @@ def test_correctness_against_cpu_reference(self):
146152
self.draft_tokens,
147153
self.block_tables,
148154
self.stop_flags,
155+
self.prompt_lens,
149156
self.seq_lens_this_time,
157+
self.seq_lens_encoder,
150158
self.seq_lens_decoder,
151159
self.step_seq_lens_decoder,
152160
self.step_draft_tokens,
@@ -165,7 +173,9 @@ def test_correctness_against_cpu_reference(self):
165173
self.np_draft_tokens,
166174
self.np_block_tables,
167175
self.np_stop_flags,
176+
self.prompt_lens,
168177
self.np_seq_lens_this_time,
178+
self.np_seq_lens_encoder,
169179
self.np_seq_lens_decoder,
170180
self.np_step_seq_lens_decoder,
171181
self.np_step_draft_tokens,
@@ -213,7 +223,9 @@ def test_no_trigger_path(self):
213223
self.draft_tokens,
214224
self.block_tables,
215225
self.stop_flags,
226+
self.prompt_lens,
216227
self.seq_lens_this_time,
228+
self.seq_lens_encoder,
217229
self.seq_lens_decoder,
218230
self.step_seq_lens_decoder,
219231
self.step_draft_tokens,

0 commit comments

Comments
 (0)