2828void SpeculateSaveWithOutputMsg (const paddle::Tensor& accept_tokens,
2929 const paddle::Tensor& accept_num,
3030 const paddle::Tensor& not_need_stop,
31+ const paddle::Tensor& seq_lens_decoder,
32+ const paddle::Tensor& prompt_lens,
3133 int64_t rank_id,
3234 int msg_queue_id,
33- int save_each_rank) {
35+ int save_each_rank,
36+ bool skip_prefill) {
3437 // printf("enter save output");
3538 if (!save_each_rank && rank_id > 0 ) {
3639 return ;
@@ -43,6 +46,11 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
4346 int64_t * accept_tokens_data = accept_tokens_cpu.data <int64_t >();
4447 int * accept_num_data = accept_num_cpu.data <int >();
4548
49+ auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to (paddle::CPUPlace (), true );
50+ auto prompt_lens_cpu = prompt_lens.copy_to (paddle::CPUPlace (), true );
51+ int * seq_lens_decoder_data = seq_lens_decoder_cpu.data <int >();
52+ int64_t * prompt_lens_data = prompt_lens_cpu.data <int64_t >();
53+
4654 if (const char * inference_msg_queue_id_env_p =
4755 std::getenv (" INFERENCE_MSG_QUEUE_ID" )) {
4856 std::string inference_msg_queue_id_env_str (
@@ -95,7 +103,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
95103 msg_sed.mtext [1 ] = bsz;
96104
97105 for (int i = 2 ; i < MAX_BSZ + 2 ; i++) {
98- if (i - 2 >= bsz) {
106+ if (i - 2 >= bsz || (skip_prefill && seq_lens_decoder_data[i - 2 ] < prompt_lens_data[i - 2 ]) ) {
99107 msg_sed.mtext [i] = 0 ;
100108 } else {
101109 msg_sed.mtext [i] = (int )accept_num_data[i - 2 ];
@@ -125,32 +133,38 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
// SpeculateSaveWithOutputMsgStatic: thin wrapper that forwards every argument
// to SpeculateSaveWithOutputMsg with the message-queue id fixed to the literal 1.
//
// Reading this span: it is a diff view, not plain source. Lines marked `-` are
// the pre-change text, lines marked `+` are the post-change text, and the
// leading digits are old/new line numbers fused in front of the code.
//
// Post-change signature (the `+` lines) adds two tensors, seq_lens_decoder and
// prompt_lens, plus a trailing bool skip_prefill, and passes all three through
// to the shared implementation. The pre-change call (the `-` line) passed only
// accept_tokens, accept_num, not_need_stop, rank_id, 1 and save_each_rank.
//
// NOTE(review): in the shared implementation visible earlier in this diff,
// skip_prefill makes a batch slot report 0 when its seq_lens_decoder entry is
// less than its prompt_lens entry -- confirm against the full upstream file.
125133void SpeculateSaveWithOutputMsgStatic (const paddle::Tensor& accept_tokens,
126134 const paddle::Tensor& accept_num,
127135 const paddle::Tensor& not_need_stop,
136+ const paddle::Tensor& seq_lens_decoder,
137+ const paddle::Tensor& prompt_lens,
128138 int64_t rank_id,
129- bool save_each_rank) {
139+ bool save_each_rank,
140+ bool skip_prefill) {
130141 SpeculateSaveWithOutputMsg (
131- accept_tokens, accept_num, not_need_stop, rank_id, 1 , save_each_rank);
142+ accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, 1 , save_each_rank, skip_prefill );
132143}
133144
// SpeculateSaveWithOutputMsgDynamic: thin wrapper that forwards every argument
// to SpeculateSaveWithOutputMsg. Unlike the Static variant, the message-queue
// id is not hard-coded: the caller-supplied msg_queue_id is passed through.
//
// Reading this span: it is a diff view, not plain source. Lines marked `-` are
// the pre-change text, lines marked `+` are the post-change text, and the
// leading digits are old/new line numbers fused in front of the code.
//
// Post-change signature (the `+` lines) adds the seq_lens_decoder and
// prompt_lens tensors and a trailing bool skip_prefill, all forwarded as-is.
// The pre-change call (the `-` line) did not pass any of those three.
134145void SpeculateSaveWithOutputMsgDynamic (const paddle::Tensor& accept_tokens,
135146 const paddle::Tensor& accept_num,
136147 const paddle::Tensor& not_need_stop,
148+ const paddle::Tensor& seq_lens_decoder,
149+ const paddle::Tensor& prompt_lens,
137150 int64_t rank_id,
138151 int msg_queue_id,
139- bool save_each_rank) {
152+ bool save_each_rank,
153+ bool skip_prefill) {
140154 SpeculateSaveWithOutputMsg (
141- accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
155+ accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, msg_queue_id, save_each_rank, skip_prefill );
142156}
143157
// Operator registration for `speculate_save_output`, bound to the kernel
// function SpeculateSaveWithOutputMsgStatic.
//
// Diff view: `-` lines are the pre-change registration, `+` lines the
// post-change one. The change appends seq_lens_decoder and prompt_lens to the
// input list and skip_prefill (bool) to the attribute list, mirroring the new
// parameters of the kernel function. The single output x_out is mapped in
// place onto the accept_tokens input via SetInplaceMap.
//
// NOTE(review): the quoted names carry leading/trailing spaces here (e.g.
// " seq_lens_decoder "); this looks like whitespace padding introduced when the
// page was scraped -- verify the exact strings against the upstream file.
144158PD_BUILD_STATIC_OP (speculate_save_output)
145- .Inputs({" accept_tokens" , " accept_num" , " not_need_stop" })
146- .Attrs({" rank_id: int64_t" , " save_each_rank: bool" })
159+ .Inputs({" accept_tokens" , " accept_num" , " not_need_stop" , " seq_lens_decoder " , " prompt_lens " })
160+ .Attrs({" rank_id: int64_t" , " save_each_rank: bool" , " skip_prefill: bool " })
147161 .Outputs({" x_out" })
148162 .SetInplaceMap({{" accept_tokens" , " x_out" }})
149163 .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));
150164
// Operator registration for `speculate_save_output_dynamic`, bound to the
// kernel function SpeculateSaveWithOutputMsgDynamic.
//
// Diff view: `-` lines are the pre-change registration, `+` lines the
// post-change one. Compared with `speculate_save_output`, the attribute list
// additionally carries msg_queue_id (int). The change appends seq_lens_decoder
// and prompt_lens to the inputs and skip_prefill (bool) to the attributes. The
// single output x_out is mapped in place onto accept_tokens via SetInplaceMap.
//
// NOTE(review): the quoted names carry leading/trailing spaces here (e.g.
// " prompt_lens "); this looks like whitespace padding introduced when the
// page was scraped -- verify the exact strings against the upstream file.
151165PD_BUILD_STATIC_OP (speculate_save_output_dynamic)
152- .Inputs({" accept_tokens" , " accept_num" , " not_need_stop" })
153- .Attrs({" rank_id: int64_t" , " msg_queue_id: int" , " save_each_rank: bool" })
166+ .Inputs({" accept_tokens" , " accept_num" , " not_need_stop" , " seq_lens_decoder " , " prompt_lens " })
167+ .Attrs({" rank_id: int64_t" , " msg_queue_id: int" , " save_each_rank: bool" , " skip_prefill: bool " })
154168 .Outputs({" x_out" })
155169 .SetInplaceMap({{" accept_tokens" , " x_out" }})
156170 .SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
0 commit comments