7 changes: 6 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -709,8 +709,11 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
bool save_each_rank);
bool save_each_rank,
bool skip_prefill);


void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
@@ -719,7 +722,9 @@ void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
34 changes: 24 additions & 10 deletions custom_ops/gpu_ops/speculate_decoding/speculate_save_output.cc
@@ -28,9 +28,12 @@
void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
int msg_queue_id,
int save_each_rank) {
int save_each_rank,
bool skip_prefill) {
// printf("enter save output");
if (!save_each_rank && rank_id > 0) {
return;
@@ -43,6 +46,11 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
int64_t* accept_tokens_data = accept_tokens_cpu.data<int64_t>();
int* accept_num_data = accept_num_cpu.data<int>();

auto seq_lens_decoder_cpu = seq_lens_decoder.copy_to(paddle::CPUPlace(), true);
auto prompt_lens_cpu = prompt_lens.copy_to(paddle::CPUPlace(), true);
int* seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
int64_t* prompt_lens_data = prompt_lens_cpu.data<int64_t>();

if (const char* inference_msg_queue_id_env_p =
std::getenv("INFERENCE_MSG_QUEUE_ID")) {
std::string inference_msg_queue_id_env_str(
@@ -95,7 +103,7 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
msg_sed.mtext[1] = bsz;

for (int i = 2; i < MAX_BSZ + 2; i++) {
if (i - 2 >= bsz) {
if (i - 2 >= bsz || (skip_prefill && seq_lens_decoder_data[i - 2] < prompt_lens_data[i - 2])) {
msg_sed.mtext[i] = 0;
} else {
msg_sed.mtext[i] = (int)accept_num_data[i - 2];
@@ -125,32 +133,38 @@ void SpeculateSaveWithOutputMsg(const paddle::Tensor& accept_tokens,
void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
bool save_each_rank) {
bool save_each_rank,
bool skip_prefill) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, 1, save_each_rank);
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, 1, save_each_rank, skip_prefill);
}

void SpeculateSaveWithOutputMsgDynamic(const paddle::Tensor& accept_tokens,
const paddle::Tensor& accept_num,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& prompt_lens,
int64_t rank_id,
int msg_queue_id,
bool save_each_rank) {
bool save_each_rank,
bool skip_prefill) {
SpeculateSaveWithOutputMsg(
accept_tokens, accept_num, not_need_stop, rank_id, msg_queue_id, save_each_rank);
accept_tokens, accept_num, not_need_stop, seq_lens_decoder, prompt_lens, rank_id, msg_queue_id, save_each_rank, skip_prefill);
}

PD_BUILD_STATIC_OP(speculate_save_output)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool"})
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
.Attrs({"rank_id: int64_t", "save_each_rank: bool", "skip_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgStatic));

PD_BUILD_STATIC_OP(speculate_save_output_dynamic)
.Inputs({"accept_tokens", "accept_num", "not_need_stop"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool"})
.Inputs({"accept_tokens", "accept_num", "not_need_stop", "seq_lens_decoder", "prompt_lens"})
.Attrs({"rank_id: int64_t", "msg_queue_id: int", "save_each_rank: bool", "skip_prefill: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"accept_tokens", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSaveWithOutputMsgDynamic));
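Net effect of the changes above: the save op can now suppress output for requests that are still mid chunked-prefill, reporting an accept count of 0 whenever seq_lens_decoder has not yet caught up with prompt_lens. Below is a minimal NumPy sketch of the updated fill loop, for illustration only — MAX_BSZ is a placeholder value here, and the real op writes the counts into a SysV message queue rather than returning an array:

    import numpy as np

    MAX_BSZ = 512  # placeholder; the actual constant comes from the op's headers

    def fill_accept_counts(accept_num, seq_lens_decoder, prompt_lens, skip_prefill):
        # Mirrors the msg_sed.mtext[2 : MAX_BSZ + 2] fill loop after this PR.
        counts = np.zeros(MAX_BSZ, dtype=np.int32)
        for i in range(min(len(accept_num), MAX_BSZ)):
            if skip_prefill and seq_lens_decoder[i] < prompt_lens[i]:
                counts[i] = 0  # request still prefilling: emit nothing for it
            else:
                counts[i] = accept_num[i]
        return counts

    # Request 1 has decoded 8 of its 16 prompt tokens, so with skip_prefill
    # its accept count is masked to 0 and no tokens are surfaced for it.
    print(fill_accept_counts(
        accept_num=np.array([3, 2], dtype=np.int32),
        seq_lens_decoder=np.array([40, 8], dtype=np.int32),
        prompt_lens=np.array([16, 16], dtype=np.int64),
        skip_prefill=True,
    )[:2])  # -> [3 0]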
50 changes: 37 additions & 13 deletions custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu
@@ -19,7 +19,9 @@ __global__ void speculate_schedula_cache(
const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
const int64_t* prompt_lens,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
@@ -44,23 +46,37 @@
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;

if (seq_lens_decoder[bid] >= prompt_lens[bid]) {
// decoder
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else {
// prefill
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_this_time[bid] = 0;
seq_lens_decoder[bid] = 0;
seq_lens_encoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
stop_flag_now_int = 1;
}


} else {
stop_flag_now_int = 1;
}
@@ -83,7 +99,9 @@ __global__ void speculate_schedula_cache(
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
@@ -109,7 +127,9 @@ void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
prompt_lens.data<int64_t>(),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
@@ -138,7 +158,9 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
.Inputs({"draft_tokens",
"block_tables",
"stop_flags",
"prompt_lens",
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"step_seq_lens_decoder",
"step_draft_tokens",
@@ -153,6 +175,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
"block_tables_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",
"step_seq_lens_decoder_out",
"step_draft_tokens_out",
@@ -165,6 +188,7 @@ PD_BUILD_STATIC_OP(speculate_schedule_cache)
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_encoder", "seq_lens_encoder_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},
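The kernel now branches on whether a request has finished consuming its prompt. Below is a per-request Python sketch of the new control flow — names mirror the CUDA kernel, s is a dict of per-request lists, max_next_step_tokens is taken as an input because its derivation sits outside this hunk, and the reduction of stop_flag_now_int into the kernel's stop bookkeeping is omitted:

    def schedule_one(bid, s, max_next_step_tokens, block_size, block_num_per_seq):
        # One thread of speculate_schedula_cache after this PR (sketch).
        if s["stop_flags"][bid]:
            return
        if s["seq_lens_decoder"][bid] >= s["prompt_lens"][bid]:
            # Decode phase: park the request ("block step") only when its next
            # speculative step could touch a KV-cache block that is not yet
            # allocated (block-table entry still -1).
            idx = (s["seq_lens_decoder"][bid] + max_next_step_tokens) // block_size
            if idx < block_num_per_seq and s["block_tables"][bid][idx] == -1:
                s["is_block_step"][bid] = True
                s["step_seq_lens_this_time"][bid] = s["seq_lens_this_time"][bid]
                s["seq_lens_this_time"][bid] = 0
                s["stop_flags"][bid] = True
                s["step_seq_lens_decoder"][bid] = s["seq_lens_decoder"][bid]
                s["seq_lens_decoder"][bid] = 0
                s["accept_num"][bid] = 0
                s["accept_tokens"][bid] = [-1] * len(s["accept_tokens"][bid])
                s["step_draft_tokens"][bid] = list(s["draft_tokens"][bid])
        else:
            # Chunked-prefill phase (new): unconditionally stop the request and
            # stash its lengths so the scheduler can resume it later; the
            # encoder length is cleared as well.
            s["stop_flags"][bid] = True
            s["step_seq_lens_decoder"][bid] = s["seq_lens_decoder"][bid]
            s["seq_lens_this_time"][bid] = 0
            s["seq_lens_decoder"][bid] = 0
            s["seq_lens_encoder"][bid] = 0
            s["accept_num"][bid] = 0
            s["accept_tokens"][bid] = [-1] * len(s["accept_tokens"][bid])
            s["step_draft_tokens"][bid] = list(s["draft_tokens"][bid])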
custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags_and_idx.cu
@@ -20,30 +20,33 @@

__global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
const int64_t *accept_tokens,
const int *accept_num,
int *accept_num,
const bool *stop_flags,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
int *seq_lens_decoder,
const int64_t *step_idx,
int bs,
int length,
int max_draft_tokens) {
int tid = threadIdx.x;
if (tid < bs && !stop_flags[tid]) {
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
const int64_t *accept_tokens_now =
accept_tokens + tid * max_draft_tokens;
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
// printf("step_idx[tid] %d\n", step_idx[tid]);
if (step_idx[tid] >= 0) {
for (int i = 0; i < accept_num[tid]; i++) {
pre_ids_all_now[step_idx[tid] - i] =
accept_tokens_now[accept_num[tid] - 1 - i];
// printf("pre_ids_all_now[step_idx[tid] - i] %d \n",
// pre_ids_all_now[step_idx[tid] - i]);

if (tid < bs) {
if (!stop_flags[tid]) {
int64_t *pre_ids_all_now = pre_ids_all + tid * length;
const int64_t *accept_tokens_now =
accept_tokens + tid * max_draft_tokens;
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
if (step_idx[tid] >= 0) {
for (int i = 0; i < accept_num[tid]; i++) {
pre_ids_all_now[step_idx[tid] - i] =
accept_tokens_now[accept_num[tid] - 1 - i];
}
}
} else {
accept_num[tid] = 0;
seq_lens_decoder[tid] = 0;
}
}
}
@@ -67,10 +70,10 @@ void SpeculateSetValueByFlagsAndIdx(const paddle::Tensor &pre_ids_all,
speculate_set_value_by_flag_and_id<<<1, block_size, 0, cu_stream>>>(
const_cast<int64_t *>(pre_ids_all.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
const_cast<int*>(accept_num.data<int>()),
stop_flags.data<bool>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
const_cast<int*>(seq_lens_decoder.data<int>()),
step_idx.data<int64_t>(),
bs,
length,
@@ -86,6 +89,9 @@ PD_BUILD_STATIC_OP(speculate_set_value_by_flags_and_idx)
"seq_lens_encoder",
"seq_lens_decoder",
"step_idx"})
.Outputs({"pre_ids_all_out"})
.SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}})
.Outputs({"pre_ids_all_out", "accept_num_out", "seq_lens_decoder_out"})
.SetInplaceMap({
{"pre_ids_all", "pre_ids_all_out"},
{"accept_num", "accept_num_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"}})
.SetKernelFn(PD_KERNEL(SpeculateSetValueByFlagsAndIdx));
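Two behavioral changes land in this file: accept_num and seq_lens_decoder become in-place outputs, and stopped requests now clear both inside this kernel — which appears to be what makes the standalone speculate_clear_accept_nums call (and the reset deleted from speculate_update.cu below) unnecessary. A per-thread Python sketch, with names mirroring the kernel:

    def set_value_by_flag_one(tid, pre_ids_all, accept_tokens, accept_num,
                              stop_flags, seq_lens_encoder, seq_lens_decoder,
                              step_idx):
        # One thread of speculate_set_value_by_flag_and_id after this PR (sketch).
        if stop_flags[tid]:
            accept_num[tid] = 0        # clearing formerly done by speculate_clear_accept_nums
            seq_lens_decoder[tid] = 0  # reset formerly done inside speculate_update
            return
        if seq_lens_decoder[tid] == 0 and seq_lens_encoder[tid] == 0:
            return  # request already drained
        if step_idx[tid] >= 0:
            # Write the newly accepted tokens into the pre_ids history so the
            # most recent token lands at position step_idx[tid].
            for i in range(accept_num[tid]):
                pre_ids_all[tid][step_idx[tid] - i] = (
                    accept_tokens[tid][accept_num[tid] - 1 - i])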
3 changes: 0 additions & 3 deletions custom_ops/gpu_ops/speculate_decoding/speculate_update.cu
@@ -71,9 +71,6 @@ __global__ void speculate_update(int *seq_lens_encoder,
}
draft_tokens[bid * max_draft_tokens] =
accept_tokens[bid * max_draft_tokens + accept_num_now - 1];
if (stop_flag_now_int) {
seq_lens_decoder[bid] = 0;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
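The reset removed here looks relocated rather than lost: speculate_set_value_by_flag_and_id (above) now zeroes seq_lens_decoder for stopped requests, so the state is still cleared for them each step.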
6 changes: 1 addition & 5 deletions fastdeploy/engine/args_utils.py
@@ -1026,11 +1026,7 @@ def create_engine_config(self) -> FDConfig:

speculative_cfg = self.create_speculative_config()
if not self.enable_chunked_prefill:
if (
current_platform.is_cuda()
and self.splitwise_role == "mixed"
and (speculative_cfg is None or speculative_cfg.method not in ["mtp"])
):
if current_platform.is_cuda() and self.splitwise_role == "mixed":
# default enable chunked prefill
self.enable_chunked_prefill = True

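With the removed clause, speculative decoding with MTP no longer opts out of the default: on CUDA, mixed-role deployments now enable chunked prefill regardless of the speculative method.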
6 changes: 3 additions & 3 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -64,7 +64,6 @@
save_output,
save_output_topk,
set_stop_value_multi_ends,
speculate_clear_accept_nums,
speculate_get_output_padding_offset,
speculate_get_padding_offset,
speculate_get_seq_lens_output,
@@ -369,12 +368,13 @@ def post_process_specualate(
model_output.accept_tokens,
model_output.accept_num,
model_output.not_need_stop,
model_output.seq_lens_decoder,
model_output.prompt_lens,
model_output.mp_rank,
save_each_rank,
envs.ENABLE_V1_KVCACHE_SCHEDULER,
)

speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder)

# Update pre_ids through accept tokens

speculate_set_value_by_flags_and_idx(
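At this call site, speculate_save_output now receives seq_lens_decoder and prompt_lens and takes envs.ENABLE_V1_KVCACHE_SCHEDULER as its skip_prefill flag, while the follow-up speculate_clear_accept_nums call (and its import above) is dropped — that clearing now happens inside speculate_set_value_by_flags_and_idx.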
4 changes: 4 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -1460,6 +1460,7 @@ def _dummy_run(
reasoning_index=self.share_inputs["reasoning_index"],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)

post_process(
@@ -1814,6 +1815,7 @@ class at the server level, which is too granular for ModelRunner.
reasoning_index=self.share_inputs["reasoning_index"][:num_running_requests],
stop_token_ids=self.share_inputs["stop_seqs"],
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
)

if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
@@ -1860,7 +1862,9 @@ class at the server level, which is too granular for ModelRunner.
self.share_inputs["draft_tokens"],
self.share_inputs["block_tables"],
self.share_inputs["stop_flags"],
self.share_inputs["prompt_lens"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_encoder"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["step_draft_tokens"],
5 changes: 5 additions & 0 deletions fastdeploy/worker/output.py
@@ -250,6 +250,11 @@ class ModelOutputData:
"""
stop_seqs_len: paddle.Tensor = None

"""
the length of input prompt
"""
prompt_lens: paddle.Tensor = None


@dataclass
class ModelRunnerOutput:
3 changes: 0 additions & 3 deletions tests/operators/test_speculate_update.py
@@ -70,9 +70,6 @@ def speculate_update_np(

draft_tokens[bid, 0] = accept_tokens[bid, accept_num[bid] - 1]

if stop_flag_now_int:
seq_lens_decoder[bid] = 0

elif inactive:
stop_flag_now_int = 1

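The NumPy reference implementation drops the same stop-flag reset that was deleted from speculate_update.cu, keeping the test oracle in line with the kernel.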