diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index f7228bcfe5..956805802f 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -332,7 +332,9 @@ paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length); + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob); void GetStopFlagsMulti(const paddle::Tensor &topk_ids, const paddle::Tensor &stop_flags, @@ -906,6 +908,32 @@ void SaveOutMmsgStatic(const paddle::Tensor& x, int64_t rank_id, bool save_each_rank); +void SpeculateGetLogits(const paddle::Tensor &draft_logits, + const paddle::Tensor &next_token_num, + const paddle::Tensor &batch_token_num, + const paddle::Tensor &cu_next_token_offset, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &logits, + const paddle::Tensor &first_token_logits, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder); + +void SpeculateInsertFirstToken(const paddle::Tensor &token_ids, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &next_tokens, + const paddle::Tensor &cu_next_token_offset, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder); + +void SpeculateGetTargetLogits(const paddle::Tensor &target_logits, + const paddle::Tensor &logits, + const paddle::Tensor &cu_batch_token_offset, + const paddle::Tensor &ori_cu_batch_token_offset, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &accept_num); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), @@ -1294,4 +1322,10 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("min_p_sampling", &MinPSamplingFromProbs, "min_p_sampling function"); m.def("save_output", &SaveOutMmsgStatic, "save_output function"); + + m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function"); + + m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function"); + + m.def("speculate_get_target_logits", &SpeculateGetTargetLogits, "speculate_get_target_logits function"); } diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu index 93c1bb38c2..0e94cd808b 100644 --- a/custom_ops/gpu_ops/rebuild_padding.cu +++ b/custom_ops/gpu_ops/rebuild_padding.cu @@ -46,6 +46,7 @@ __global__ void RebuildPaddingKernel(T *output_data, template __global__ void RebuildAppendPaddingKernel(T *output_data, + T *first_token_out, const T *input_data, const int *cu_seqlens_q, const int *seq_len_this_time, @@ -55,7 +56,8 @@ __global__ void RebuildAppendPaddingKernel(T *output_data, const int max_input_length, const int dim_embed, const int64_t output_elem_nums, - const int bsz) { + const int bsz, + const bool enable_logprob) { AlignedVector src_vec; const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; for (int64_t i = global_idx * VecSize; i < output_elem_nums; @@ -70,13 +72,20 @@ __global__ void RebuildAppendPaddingKernel(T *output_data, if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue; if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1; - const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi]; + const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi]; 
const int input_token_id = ori_token_id - cum_offset_bi + seq_id; const int bias_idx = i % dim_embed; Load(&input_data[input_token_id * dim_embed + bias_idx], &src_vec); Store(src_vec, &output_data[i]); + + if (enable_logprob && seq_len_encoder[bi] > 0) { + const int first_input_token_id = input_token_id - 1; + Load(&input_data[first_input_token_id * dim_embed + bias_idx], + &src_vec); + Store(src_vec, &first_token_out[bi * dim_embed + bias_idx]); + } } } @@ -89,7 +98,9 @@ std::vector rebuild_padding( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -135,6 +146,10 @@ std::vector rebuild_padding( RebuildAppendPaddingKernel <<>>( reinterpret_cast(out.data()), + first_token_out.is_initialized() + ? reinterpret_cast(const_cast( + first_token_out.get_ptr()->data())) + : nullptr, reinterpret_cast(tmp_out.data()), cu_seqlens_q.data(), seq_len_this_time.data(), @@ -144,7 +159,8 @@ std::vector rebuild_padding( max_input_length, dim_embed, elem_nums, - bsz); + bsz, + enable_logprob); } else { RebuildPaddingKernel <<>>( @@ -169,7 +185,9 @@ paddle::Tensor RebuildPaddingFunc( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { switch (tmp_out.type()) { case paddle::DataType::BFLOAT16: { return rebuild_padding( @@ -179,7 +197,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } case paddle::DataType::FLOAT16: { return rebuild_padding( @@ -189,7 +209,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } case paddle::DataType::FLOAT32: { return rebuild_padding( @@ -199,7 +221,9 @@ paddle::Tensor RebuildPaddingFunc( seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)[0]; + first_token_out, + max_input_length, + enable_logprob)[0]; } default: { PD_THROW( @@ -217,14 +241,18 @@ std::vector RebuildPadding( const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_encoder, const paddle::optional &output_padding_offset, - int max_input_length) { + const paddle::optional &first_token_out, + int max_input_length, + bool enable_logprob) { return {RebuildPaddingFunc(tmp_out, cu_seqlens_q, seq_len_this_time, seq_lens_decoder, seq_lens_encoder, output_padding_offset, - max_input_length)}; + first_token_out, + max_input_length, + enable_logprob)}; } std::vector> RebuildPaddingInferShape( @@ -260,9 +288,10 @@ PD_BUILD_STATIC_OP(rebuild_padding) "seq_len_this_time", "seq_lens_decoder", "seq_lens_encoder", - paddle::Optional("output_padding_offset")}) + paddle::Optional("output_padding_offset"), + paddle::Optional("first_token_out")}) .Outputs({"out"}) - .Attrs({"max_input_length: int"}) + .Attrs({"max_input_length: int", "enable_logprob: bool"}) .SetKernelFn(PD_KERNEL(RebuildPadding)) .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingInferDtype)); diff --git 
a/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc new file mode 100644 index 0000000000..c87b27d1e2 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#define MAX_BSZ 512 +#define K 20 +#define MAX_DRAFT_TOKEN_NUM 6 + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + long mtype; + int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums + batch_msgdata mtext[MAX_BSZ]; +}; + +void SpeculateGetOutMmsgTopK(const paddle::Tensor& output_tokens, + const paddle::Tensor& output_scores, + const paddle::Tensor& output_ranks, + int real_k, + int64_t rank_id, + bool wait_flag) { + struct msgdata msg_rcv; + int msg_queue_id = 1; + + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str( + inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + msg_queue_id = inference_msg_queue_id_from_env; + } + static key_t key = ftok("/dev/shm", msg_queue_id); + + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "get_output_key: " << key << std::endl; + std::cout << "get_output msgid: " << msgid << std::endl; +#endif + + int64_t* output_tokens_data = + const_cast(output_tokens.data()); + float* output_scores_data = const_cast(output_scores.data()); + int64_t* output_ranks_data = + const_cast(output_ranks.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv( + msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, IPC_NOWAIT); + } else { + ret = msgrcv(msgid, &msg_rcv, sizeof(msg_rcv) - sizeof(long), 0, 0); + } + if (ret == -1) { + // read none + output_tokens_data[0] = -2; // stop_flag + output_tokens_data[1] = 0; // message_flag, Target: 3, Draft: 4 + output_tokens_data[2] = 0; // bsz + return; + } + + int bsz = msg_rcv.meta[2]; + output_tokens_data[0] = (int64_t)msg_rcv.meta[0]; + output_tokens_data[1] = (int64_t)msg_rcv.meta[1]; + output_tokens_data[2] = (int64_t)msg_rcv.meta[2]; + + int output_tokens_offset = 3 + MAX_BSZ; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_rcv.meta[3 + i]; + output_tokens_data[3 + i] = (int64_t)cur_token_num; // batch_token_nums + + auto* cur_output_token = output_tokens_data + output_tokens_offset + + i * (MAX_DRAFT_TOKEN_NUM * 
(K + 1)); + auto* cur_output_score = + output_scores_data + i * (MAX_DRAFT_TOKEN_NUM * (K + 1)); + auto* cur_batch_msg_rcv = &msg_rcv.mtext[i]; + for (int j = 0; j < cur_token_num; j++) { + for (int k = 0; k < real_k + 1; k++) { + cur_output_token[j * (K + 1) + k] = + (int64_t)cur_batch_msg_rcv->tokens[j * (K + 1) + k]; + cur_output_score[j * (K + 1) + k] = + cur_batch_msg_rcv->scores[j * (K + 1) + k]; + } + output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] = + (int64_t)cur_batch_msg_rcv->ranks[j]; + } + } +#ifdef SPECULATE_GET_WITH_OUTPUT_DEBUG + std::cout << "msg data: " << std::endl; + std::cout << "stop_flag: " << output_tokens_data[0] + << ", message_flag: " << output_tokens_data[1] + << ", bsz: " << output_tokens_data[2] << std::endl; + for (int i = 0; i < output_tokens_data[2]; i++) { + int cur_token_num = output_tokens_data[3 + i]; + std::cout << "batch " << i << " token_num: " << cur_token_num + << std::endl; + for (int j = 0; j < cur_token_num; j++) { + std::cout << "tokens: "; + for (int k = 0; k < K + 1; k++) { + std::cout + << output_tokens_data[output_tokens_offset + + i * MAX_DRAFT_TOKEN_NUM * (K + 1) + + j * (K + 1) + k] + << " "; + } + std::cout << std::endl; + std::cout << "scores: "; + for (int k = 0; k < K + 1; k++) { + std::cout + << output_scores_data[i * MAX_DRAFT_TOKEN_NUM * (K + 1) + + j * (K + 1) + k] + << " "; + } + std::cout << std::endl; + std::cout << "ranks: " + << output_ranks_data[i * MAX_DRAFT_TOKEN_NUM + j] + << std::endl; + } + } + std::cout << std::endl; +#endif + return; +} + +PD_BUILD_STATIC_OP(speculate_get_output_topk) + .Inputs({"output_tokens", "output_scores", "output_ranks"}) + .Attrs({"real_k: int", "rank_id: int64_t", "wait_flag: bool"}) + .Outputs({"output_tokens_out", "output_scores_out", "output_ranks_out"}) + .SetInplaceMap({{"output_tokens", "output_tokens_out"}, + {"output_scores", "output_scores_out"}, + {"output_ranks", "output_ranks_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetOutMmsgTopK)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu new file mode 100644 index 0000000000..a20773bf57 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_logprob_utils.cu @@ -0,0 +1,290 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" + +__global__ void get_token_num_per_batch_kernel(int* next_token_num, + int* batch_token_num, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int real_bsz) { + int bid = threadIdx.x; + if (bid < real_bsz) { + next_token_num[bid] = + seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid]; + batch_token_num[bid] = + seq_lens_encoder[bid] > 0 ? 
2 : seq_lens_this_time[bid]; + } +} + +template +__global__ void speculate_get_logits_kernel(float* draft_logits, + const float* logits, + const float* first_token_logits, + const int* cu_next_token_offset, + const int* cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int vocab_size, + const int real_bsz) { + AlignedVector src_vec; + const int bid = blockIdx.x; + const int tid = threadIdx.x; + if (bid < real_bsz) { + auto* draft_logits_now = + draft_logits + cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + cu_next_token_offset[bid] * vocab_size; + for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { + if (seq_lens_encoder[bid] > 0) { + Load(&first_token_logits[bid * vocab_size + i], + &src_vec); + Store(src_vec, &draft_logits_now[i]); + + Load(&logits_now[i], &src_vec); + Store(src_vec, + &draft_logits_now[vocab_size + i]); + } else { + for (int j = 0; j < seq_lens_this_time[bid]; j++) { + Load(&logits_now[j * vocab_size + i], + &src_vec); + Store( + src_vec, &draft_logits_now[j * vocab_size + i]); + } + } + } + } +} + +void SpeculateGetLogits(const paddle::Tensor& draft_logits, + const paddle::Tensor& next_token_num, + const paddle::Tensor& batch_token_num, + const paddle::Tensor& cu_next_token_offset, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& logits, + const paddle::Tensor& first_token_logits, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder) { + auto cu_stream = seq_lens_this_time.stream(); + const int vocab_size = logits.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + get_token_num_per_batch_kernel<<<1, 512, 0, cu_stream>>>( + const_cast(next_token_num.data()), + const_cast(batch_token_num.data()), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + real_bsz); + + void* temp_storage1 = nullptr; + size_t temp_storage_bytes1 = 0; + cub::DeviceScan::InclusiveSum( + temp_storage1, + temp_storage_bytes1, + batch_token_num.data(), + const_cast(&cu_batch_token_offset.data()[1]), + real_bsz, + cu_stream); + cudaMalloc(&temp_storage1, temp_storage_bytes1); + cub::DeviceScan::InclusiveSum( + temp_storage1, + temp_storage_bytes1, + batch_token_num.data(), + const_cast(&cu_batch_token_offset.data()[1]), + real_bsz, + cu_stream); + + void* temp_storage2 = nullptr; + size_t temp_storage_bytes2 = 0; + cub::DeviceScan::InclusiveSum( + temp_storage2, + temp_storage_bytes2, + next_token_num.data(), + const_cast(&cu_next_token_offset.data()[1]), + real_bsz, + cu_stream); + cudaMalloc(&temp_storage2, temp_storage_bytes2); + cub::DeviceScan::InclusiveSum( + temp_storage2, + temp_storage_bytes2, + next_token_num.data(), + const_cast(&cu_next_token_offset.data()[1]), + real_bsz, + cu_stream); + + constexpr int PackSize = VEC_16B / sizeof(float); + dim3 grid_dim(real_bsz); + dim3 block_dim(128); + speculate_get_logits_kernel + <<>>( + const_cast(draft_logits.data()), + logits.data(), + first_token_logits.data(), + cu_next_token_offset.data(), + cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + vocab_size, + real_bsz); +} + +__global__ void speculate_insert_first_token_kernel( + int64_t* token_ids, + const int64_t* accept_tokens, + const int64_t* next_tokens, + const int* cu_next_token_offset, + const int* cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int max_draft_tokens, + const int real_bsz) { + const int bid = threadIdx.x; + + auto* 
token_ids_now = token_ids + cu_batch_token_offset[bid]; + auto* accept_tokens_now = accept_tokens + bid * max_draft_tokens; + auto* next_tokens_now = next_tokens + cu_next_token_offset[bid]; + if (seq_lens_encoder[bid] != 0) { + token_ids_now[0] = accept_tokens_now[0]; + token_ids_now[1] = next_tokens_now[0]; + } else { + for (int i = 0; i < seq_lens_this_time[bid]; i++) { + token_ids_now[i] = next_tokens_now[i]; + } + } +} + +void SpeculateInsertFirstToken(const paddle::Tensor& token_ids, + const paddle::Tensor& accept_tokens, + const paddle::Tensor& next_tokens, + const paddle::Tensor& cu_next_token_offset, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder) { + auto cu_stream = seq_lens_this_time.stream(); + const int max_draft_tokens = accept_tokens.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + speculate_insert_first_token_kernel<<<1, real_bsz, 0, cu_stream>>>( + const_cast(token_ids.data()), + accept_tokens.data(), + next_tokens.data(), + cu_next_token_offset.data(), + cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_draft_tokens, + real_bsz); +} + +template +__global__ void speculate_get_target_logits_kernel( + float* target_logtis, + const float* logits, + const int* cu_batch_token_offset, + const int* ori_cu_batch_token_offset, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int* accept_num, + const int vocab_size, + const int real_bsz) { + AlignedVector src_vec; + const int bid = blockIdx.x; + const int tid = threadIdx.x; + if (bid < real_bsz) { + auto* target_logtis_now = + target_logtis + cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + ori_cu_batch_token_offset[bid] * vocab_size; + for (int i = tid * VecSize; i < vocab_size; i += blockDim.x * VecSize) { + if (seq_lens_encoder[bid] > 0) { + Load(&logits_now[i], &src_vec); + Store(src_vec, &target_logtis_now[i]); + } else { + for (int j = 0; j < accept_num[bid]; j++) { + Load(&logits_now[j * vocab_size + i], + &src_vec); + Store( + src_vec, &target_logtis_now[j * vocab_size + i]); + } + } + } + } +} + +void SpeculateGetTargetLogits(const paddle::Tensor& target_logits, + const paddle::Tensor& logits, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& ori_cu_batch_token_offset, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& accept_num) { + auto cu_stream = seq_lens_this_time.stream(); + const int vocab_size = logits.shape()[1]; + const int real_bsz = seq_lens_this_time.shape()[0]; + + constexpr int PackSize = VEC_16B / sizeof(float); + dim3 grid_dim(real_bsz); + dim3 block_dim(128); + speculate_get_target_logits_kernel + <<>>( + const_cast(target_logits.data()), + logits.data(), + cu_batch_token_offset.data(), + ori_cu_batch_token_offset.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + accept_num.data(), + vocab_size, + real_bsz); +} + +PD_BUILD_STATIC_OP(speculate_get_logits) + .Inputs({"draft_logits", + "next_token_num", + "batch_token_num", + "cu_next_token_offset", + "cu_batch_token_offset", + "logits", + "first_token_logits", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"draft_logits_out", + "batch_token_num_out", + "cu_batch_token_offset_out"}) + .SetInplaceMap({{"draft_logits", "draft_logits_out"}, + {"batch_token_num", "batch_token_num_out"}, + {"cu_batch_token_offset", "cu_batch_token_offset_out"}}) + 
.SetKernelFn(PD_KERNEL(SpeculateGetLogits)); + +PD_BUILD_STATIC_OP(speculate_insert_first_token) + .Inputs({"token_ids", + "accept_tokens", + "next_tokens", + "cu_next_token_offset", + "cu_batch_token_offset", + "seq_lens_this_time", + "seq_lens_encoder"}) + .Outputs({"token_ids_out"}) + .SetInplaceMap({{"token_ids", "token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateInsertFirstToken)); + +PD_BUILD_STATIC_OP(speculate_get_target_logits) + .Inputs({"target_logits", + "logits", + "cu_batch_token_offset", + "ori_cu_batch_token_offset", + "seq_lens_this_time", + "seq_lens_encoder", + "accept_num"}) + .Outputs({"target_logits_out"}) + .SetInplaceMap({{"target_logits", "target_logits_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateGetTargetLogits)); diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc new file mode 100644 index 0000000000..78eb6c1d48 --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#define MAX_BSZ 512 +#define K 20 +#define MAX_DRAFT_TOKEN_NUM 6 + +struct batch_msgdata { + int tokens[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + float scores[MAX_DRAFT_TOKEN_NUM * (K + 1)]; + int ranks[MAX_DRAFT_TOKEN_NUM]; +}; + +struct msgdata { + long mtype; + int meta[3 + MAX_BSZ]; // stop_flag, message_flag, bsz, batch_token_nums + batch_msgdata mtext[MAX_BSZ]; +}; + +void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids, + const paddle::Tensor& logprob_token_ids, + const paddle::Tensor& logprob_scores, + const paddle::Tensor& logprob_ranks, + const paddle::Tensor& token_num_per_batch, + const paddle::Tensor& cu_batch_token_offset, + const paddle::Tensor& not_need_stop, + int message_flag, // Target: 3, Draft: 4 + int64_t rank_id) { + if (rank_id > 0) { + return; + } + auto sampled_token_ids_cpu = + sampled_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_token_ids_cpu = + logprob_token_ids.copy_to(paddle::CPUPlace(), false); + auto logprob_scores_cpu = logprob_scores.copy_to(paddle::CPUPlace(), false); + auto logprob_ranks_cpu = logprob_ranks.copy_to(paddle::CPUPlace(), false); + auto token_num_per_batch_cpu = + token_num_per_batch.copy_to(paddle::CPUPlace(), false); + auto cu_batch_token_offset_cpu = + cu_batch_token_offset.copy_to(paddle::CPUPlace(), false); + int64_t* sampled_token_ids_data = sampled_token_ids_cpu.data(); + int64_t* logprob_token_ids_data = logprob_token_ids_cpu.data(); + float* logprob_scores_data = logprob_scores_cpu.data(); + int64_t* logprob_ranks_data = logprob_ranks_cpu.data(); + int* token_num_per_batch_data = token_num_per_batch_cpu.data(); + int* cu_batch_token_offset_data = 
cu_batch_token_offset_cpu.data(); + + static struct msgdata msg_sed; + int msg_queue_id = 1; + if (const char* inference_msg_queue_id_env_p = + std::getenv("INFERENCE_MSG_QUEUE_ID")) { + std::string inference_msg_queue_id_env_str( + inference_msg_queue_id_env_p); + int inference_msg_queue_id_from_env = + std::stoi(inference_msg_queue_id_env_str); + msg_queue_id = inference_msg_queue_id_from_env; +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " + << inference_msg_queue_id_from_env << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Failed to get INFERENCE_MSG_QUEUE_ID from env, use default." + << std::endl; +#endif + } + int inference_msg_id_from_env = 1; + if (const char* inference_msg_id_env_p = std::getenv("INFERENCE_MSG_ID")) { + std::string inference_msg_id_env_str(inference_msg_id_env_p); + inference_msg_id_from_env = std::stoi(inference_msg_id_env_str); + if (inference_msg_id_from_env == 2) { + // 2 and -2 are reserved for no-output indication. + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be 2, please use other number."); + } + if (inference_msg_id_from_env < 0) { + throw std::runtime_error( + " INFERENCE_MSG_ID cannot be negative, please use other " + "number."); + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "Your INFERENCE_MSG_ID is: " << inference_msg_id_from_env + << std::endl; +#endif + } else { +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout + << "Failed to get INFERENCE_MSG_ID from env, use (int)1 as default." + << std::endl; +#endif + } + static key_t key = ftok("/dev/shm", msg_queue_id); + static int msgid = msgget(key, IPC_CREAT | 0666); +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "save_output_key: " << key << std::endl; + std::cout << "save msgid: " << msgid << std::endl; +#endif + msg_sed.mtype = 1; + msg_sed.meta[0] = not_need_stop.data()[0] + ?
inference_msg_id_from_env + : -inference_msg_id_from_env; + msg_sed.meta[1] = message_flag; + int bsz = token_num_per_batch.shape()[0]; + msg_sed.meta[2] = bsz; + int max_num_logprobs = logprob_token_ids.shape()[1]; + for (int i = 0; i < bsz; i++) { + int cur_token_num = token_num_per_batch_data[i]; + msg_sed.meta[3 + i] = cur_token_num; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + int token_offset = cu_batch_token_offset_data[i]; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + for (int k = 0; k < K + 1; k++) { + if (k == 0) { + cur_tokens[k] = + (int)sampled_token_ids_data[token_offset + j]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (K + 1) + k]; + } else if (k < max_num_logprobs) { + cur_tokens[k] = (int) + logprob_token_ids_data[(token_offset + j) * (K + 1) + + k]; + cur_scores[k] = + logprob_scores_data[(token_offset + j) * (K + 1) + k]; + } else { + cur_tokens[k] = -1; + cur_scores[k] = 0.0; + } + } + cur_batch_msg_sed->ranks[j] = + (int)logprob_ranks_data[token_offset + j]; + } + } +#ifdef SPECULATE_SAVE_WITH_OUTPUT_DEBUG + std::cout << "msg data: " << std::endl; + std::cout << "stop_flag: " << msg_sed.meta[0] + << ", message_flag: " << msg_sed.meta[1] + << ", bsz: " << msg_sed.meta[2] << std::endl; + for (int i = 0; i < bsz; i++) { + int cur_token_num = msg_sed.meta[3 + i]; + auto* cur_batch_msg_sed = &msg_sed.mtext[i]; + std::cout << "batch " << i << " token_num: " << cur_token_num + << std::endl; + for (int j = 0; j < cur_token_num; j++) { + auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (K + 1)]; + auto* cur_scores = &cur_batch_msg_sed->scores[j * (K + 1)]; + std::cout << "tokens: "; + for (int k = 0; k < K + 1; k++) { + std::cout << cur_tokens[k] << " "; + } + std::cout << std::endl; + std::cout << "scores: "; + for (int k = 0; k < K + 1; k++) { + std::cout << cur_scores[k] << " "; + } + std::cout << std::endl; + std::cout << "ranks: " << cur_batch_msg_sed->ranks[j] << std::endl; + } + } + std::cout << std::endl; +#endif + if (msgsnd(msgid, &msg_sed, sizeof(msg_sed) - sizeof(long), 0) == -1) { + printf("full msg buffer\n"); + } +} + +PD_BUILD_STATIC_OP(speculate_save_output_topk) + .Inputs({ + "sampled_token_ids", + "logprob_token_ids", + "logprob_scores", + "logprob_ranks", + "token_num_per_batch", + "cu_batch_token_offset", + "not_need_stop", + }) + .Attrs({"message_flag: int", "rank_id: int64_t"}) + .SetKernelFn(PD_KERNEL(SpeculateSaveOutMmsgTopK)); diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index ba973742a4..2e8aabea57 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -403,8 +403,6 @@ def __post_init__(self): if self.dynamic_load_weight: self.enable_prefix_caching = False if self.enable_logprob: - if self.speculative_config is not None: - raise NotImplementedError("Logprob does not support speculation_config.") if not current_platform.is_cuda(): raise NotImplementedError("Only CUDA platform supports logprob.") if self.splitwise_role != "mixed": diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 04a2276afb..9af8d76e80 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -287,6 +287,7 @@ class CompletionOutput: token_ids: list[int] logprob: Optional[float] = None top_logprobs: Optional[LogprobsLists] = None + draft_top_logprobs: Optional[LogprobsLists] = None logprobs: Optional[SampleLogprobs] = None 
draft_token_ids: list[int] = None text: Optional[str] = None @@ -303,6 +304,7 @@ def to_dict(self): "token_ids": self.token_ids, "logprob": self.logprob, "top_logprobs": self.top_logprobs, + "draft_top_logprobs": self.draft_top_logprobs, "logprobs": self.logprobs, "draft_token_ids": self.draft_token_ids, "text": self.text, @@ -328,6 +330,8 @@ def __repr__(self) -> str: f"draft_token_ids={self.draft_token_ids}, " f"reasoning_content={self.reasoning_content!r}, " f"logprobs={self.logprobs}, " + f"top_logprobs={self.top_logprobs}, " + f"draft_top_logprobs={self.draft_top_logprobs}, " ) @@ -412,6 +416,7 @@ def __init__( request_id: str, prompt: Optional[str] = None, prompt_token_ids: Optional[list[int]] = None, + output_type: Optional[int] = 3, outputs: CompletionOutput = None, finished: bool = False, metrics: Optional[RequestMetrics] = None, @@ -422,6 +427,7 @@ def __init__( self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.output_type = output_type self.outputs = outputs self.finished = finished self.metrics = metrics @@ -450,12 +456,21 @@ def add(self, next_output: RequestOutput) -> None: self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids) self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs) self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks) + if next_output.outputs.draft_top_logprobs is not None: + self.outputs.draft_top_logprobs.logprob_token_ids.extend( + next_output.outputs.draft_top_logprobs.logprob_token_ids + ) + self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs) + self.outputs.draft_top_logprobs.sampled_token_ranks.extend( + next_output.outputs.draft_top_logprobs.sampled_token_ranks + ) def __repr__(self) -> str: return ( f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"output_type={self.output_type}, " f"outputs={self.outputs}, " f"finished={self.finished}, " f"num_cached_tokens={self.num_cached_tokens}, " @@ -476,6 +491,7 @@ def to_dict(self): "request_id": self.request_id, "prompt": self.prompt, "prompt_token_ids": self.prompt_token_ids, + "output_type": self.output_type, "outputs": None if self.outputs is None else self.outputs.to_dict(), "metrics": None if self.metrics is None else self.metrics.to_dict(), "finished": self.finished, diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index b74e0ffb46..590a9a279b 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -184,6 +184,7 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] @@ -246,6 +247,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None + draft_logprobs: Optional[LogProbs] = None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None arrival_time: Optional[float] = None @@ -278,6 +280,7 @@ class CompletionResponseChoice(BaseModel): completion_tokens: Optional[str] = None arrival_time: Optional[float] = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None reasoning_content: Optional[str] = 
None finish_reason: Optional[Literal["stop", "length", "tool_calls"]] tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None @@ -316,6 +319,7 @@ class CompletionResponseStreamChoice(BaseModel): text: str arrival_time: float = None logprobs: Optional[CompletionLogprobs] = None + draft_logprobs: Optional[CompletionLogprobs] = None prompt_token_ids: Optional[List[int]] = None completion_token_ids: Optional[List[int]] = None text_after_process: Optional[str] = None @@ -405,6 +409,7 @@ class CompletionRequest(BaseModel): echo: Optional[bool] = False frequency_penalty: Optional[float] = None logprobs: Optional[int] = None + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False @@ -540,6 +545,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = None logprobs: Optional[bool] = False top_logprobs: Optional[int] = 0 + include_draft_logprobs: Optional[bool] = False # For logits and logprobs post processing temp_scaled_logprobs: bool = False diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 125d785fe3..d261a650ba 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -293,12 +293,18 @@ async def chat_completion_stream_generator( output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] previous_num_tokens += len(output["token_ids"]) logprobs_res: Optional[LogProbs] = None + draft_logprobs_res: Optional[LogProbs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) delta_message = DeltaMessage( reasoning_content="", @@ -326,6 +332,7 @@ async def chat_completion_stream_generator( index=0, delta=delta_message, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, arrival_time=arrival_time, ) if res["finished"]: @@ -420,6 +427,7 @@ async def chat_completion_full_generator( previous_num_tokens = 0 current_waiting_time = 0 logprob_contents = [] + draft_logprob_contents = [] completion_token_ids = [] response_processor = ChatResponseProcessor( data_processor=self.engine_client.data_processor, @@ -460,12 +468,23 @@ async def chat_completion_full_generator( # The logprob for handling the response output = data["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] if output_top_logprobs is not None: + # logprobs logprobs_res = self._create_chat_logprobs( output_top_logprobs, request.logprobs, request.top_logprobs ) if logprobs_res and logprobs_res.content is not None: logprob_contents.extend(logprobs_res.content) + + # draf_logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_chat_logprobs( + output_draft_top_logprobs, request.logprobs, request.top_logprobs + ) + if draft_logprobs_res and draft_logprobs_res.content is not None: + draft_logprob_contents.extend(draft_logprobs_res.content) + if data["finished"]: final_res = data task_is_finished = True @@ -499,11 +518,15 @@ async def chat_completion_full_generator( logprobs_full_res = 
None if logprob_contents: logprobs_full_res = LogProbs(content=logprob_contents) + draft_logprobs_full_res = None + if draft_logprob_contents: + draft_logprobs_full_res = LogProbs(content=draft_logprob_contents) choice = ChatCompletionResponseChoice( index=0, message=message, logprobs=logprobs_full_res, + draft_logprobs=draft_logprobs_full_res, finish_reason=None, ) has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 9b089d073d..dbc2aabf64 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -212,6 +212,7 @@ async def completion_full_generator( valid_results = [dict()] * num_choices output_tokens = [0] * num_choices aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)] + aggregated_draft_top_logprobs = [[[], [], []] for _ in range(num_choices)] aggregated_token_ids = [[] for _ in range(num_choices)] completion_batched_token_ids = [[] for _ in range(num_choices)] current_waiting_time = 0 @@ -238,12 +239,19 @@ async def completion_full_generator( raise ValueError("{}".format(data["error_msg"])) output = data["outputs"] - output_top_logprobs = output["top_logprobs"] + output_top_logprobs = output.get("top_logprobs") or None + output_draft_top_logprobs = output.get("draft_top_logprobs") or None if output_top_logprobs is not None: aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0]) aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1]) aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2]) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + aggregated_draft_top_logprobs[rid][0].extend(output_draft_top_logprobs[0]) + aggregated_draft_top_logprobs[rid][1].extend(output_draft_top_logprobs[1]) + aggregated_draft_top_logprobs[rid][2].extend(output_draft_top_logprobs[2]) + aggregated_token_ids[rid].extend(data["outputs"]["token_ids"]) self.engine_client.data_processor.process_response_dict( @@ -254,6 +262,7 @@ async def completion_full_generator( if data.get("finished", False): data["output_token_ids"] = output_tokens[rid] data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid] + data["outputs"]["draft_top_logprobs"] = aggregated_draft_top_logprobs[rid] data["outputs"]["token_ids"] = aggregated_token_ids[rid] valid_results[rid] = data num_choices -= 1 @@ -390,10 +399,17 @@ async def completion_stream_generator( await self._echo_back_prompt(request, res, idx) output = res["outputs"] output_top_logprobs = output["top_logprobs"] + output_draft_top_logprobs = output["draft_top_logprobs"] logprobs_res: Optional[CompletionLogprobs] = None + draft_logprobs_res: Optional[CompletionLogprobs] = None if request.logprobs and output_top_logprobs is not None: logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + # draft logprobs + if request.include_draft_logprobs and output_draft_top_logprobs is not None: + draft_logprobs_res = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) output_tokens[idx] += 1 delta_message = CompletionResponseStreamChoice( index=idx, @@ -406,6 +422,7 @@ async def completion_stream_generator( reasoning_content="", arrival_time=arrival_time, logprobs=logprobs_res, + draft_logprobs=draft_logprobs_res, ) if not res["finished"] and "delta_message" in output: delta_message_output = 
output["delta_message"] @@ -493,12 +510,19 @@ def request_output_to_completion_response( completion_token_ids = completion_batched_token_ids[idx] output = final_res["outputs"] - output_top_logprobs = output["top_logprobs"] + output_top_logprobs = output.get("top_logprobs") or None + output_draft_top_logprobs = output.get("draft_top_logprobs") or None aggregated_logprobs: Optional[CompletionLogprobs] = None if output_top_logprobs is not None: aggregated_logprobs = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0) + aggregated_draft_logprobs: Optional[CompletionLogprobs] = None + if output_draft_top_logprobs is not None: + aggregated_draft_logprobs = self._create_completion_logprobs( + output_draft_top_logprobs, request.logprobs, 0 + ) + if request.echo: assert prompt_text is not None token_ids = [*prompt_token_ids, *output["token_ids"]] @@ -524,6 +548,7 @@ def request_output_to_completion_response( reasoning_content=output.get("reasoning_content"), tool_calls=output.get("tool_call"), logprobs=aggregated_logprobs, + draft_logprobs=aggregated_draft_logprobs, finish_reason=finish_reason, ) choices.append(choice_data) diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py index 09834b305a..17952e96a4 100644 --- a/fastdeploy/model_executor/layers/sample/ops/__init__.py +++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py @@ -18,6 +18,10 @@ apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, ) +from .speculate_logprob_utils import ( + speculate_get_target_logits, + speculate_insert_first_token, +) from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling __all__ = [ @@ -25,4 +29,6 @@ "apply_speculative_penalty_multi_scores", "top_k_top_p_sampling", "min_p_sampling", + "speculate_get_target_logits", + "speculate_insert_first_token", ] diff --git a/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py b/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py new file mode 100644 index 0000000000..2caaf4892b --- /dev/null +++ b/fastdeploy/model_executor/layers/sample/ops/speculate_logprob_utils.py @@ -0,0 +1,72 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle + +from fastdeploy.platforms import current_platform + + +def speculate_get_target_logits( + target_logits: paddle.Tensor, + logits: paddle.Tensor, + cu_batch_token_offset: paddle.Tensor, + ori_cu_batch_token_offset: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + accept_num: paddle.Tensor, +): + """ + speculate_get_target_logits + """ + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import speculate_get_target_logits + + speculate_get_target_logits( + target_logits, + logits, + cu_batch_token_offset, + ori_cu_batch_token_offset, + seq_lens_this_time, + seq_lens_encoder, + accept_num, + ) + else: + raise NotImplementedError + + +def speculate_insert_first_token( + token_ids: paddle.Tensor, + accept_tokens: paddle.Tensor, + next_tokens: paddle.Tensor, + cu_next_token_offset: paddle.Tensor, + cu_batch_token_offset: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, +): + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import speculate_insert_first_token + + speculate_insert_first_token( + token_ids, + accept_tokens, + next_tokens, + cu_next_token_offset, + cu_batch_token_offset, + seq_lens_this_time, + seq_lens_encoder, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 6a8db178fd..477acea217 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -34,6 +34,8 @@ apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, min_p_sampling, + speculate_get_target_logits, + speculate_insert_first_token, top_k_top_p_sampling, ) from fastdeploy.platforms import current_platform @@ -382,6 +384,7 @@ def __init__(self, fd_config: FDConfig): self.speculative_verify_window = fd_config.speculative_config.verify_window self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode + self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -396,6 +399,98 @@ def apply_logits_processor( """apply logits processor to sampler""" pass + def compute_logprobs( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + ) -> paddle.Tensor: + """compute logprobs""" + share_inputs = sampling_metadata.share_inputs + last_logits = logits + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + batch_token_num = share_inputs["batch_token_num"][:real_bsz] + + temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs + top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs + if temp_scaled_logprobs is not None: + real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] + temperature = sampling_metadata.temperature[:real_bsz] + real_bsz_temp_scaled = ( + real_bsz_temp_scaled.astype("int32").squeeze(1).repeat_interleave(batch_token_num).astype("bool") + ) + temperature = temperature.squeeze(1).repeat_interleave(batch_token_num) + temp_temperature = paddle.where( + real_bsz_temp_scaled, temperature, paddle.ones_like(temperature) + ).unsqueeze(1) + last_logits = last_logits / temp_temperature + + last_logprobs = F.log_softmax(last_logits, axis=-1) + top_p_logprob = None + top_p_token_mask = None + + if top_p_normalized_logprobs is not None and 
share_inputs is not None: + real_token_top_p = ( + sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(batch_token_num).unsqueeze(1) + ) + top_p_normalized_logprobs = ( + top_p_normalized_logprobs[:real_bsz] + .astype("int32") + .squeeze(1) + .repeat_interleave(batch_token_num) + .astype("bool") + .unsqueeze(1) + ) + top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) + if top_p_token_mask.any(): + probs = F.softmax(last_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) + if top_p_logprob is not None: + last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + # Get with the logprob of the prompt or sampled token. + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + if num_logprobs >= 1: + # Find the topK values. 
+ topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, @@ -472,7 +567,56 @@ def forward_cuda( line_break_id, ) - return None + num_logprobs = sampling_metadata.max_num_logprobs + batch_token_num = None + if num_logprobs is not None: + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + batch_token_num = paddle.where( + share_inputs["seq_lens_encoder"][:real_bsz] != 0, + paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]), + share_inputs["accept_num"][:real_bsz].unsqueeze(1), + ).squeeze(1) + share_inputs["batch_token_num"] = batch_token_num + ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype( + "int32" + ) + cu_batch_token_offset = paddle.concat( + [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])] + ).astype("int32") + share_inputs["cu_batch_token_offset"] = cu_batch_token_offset + target_logtis = paddle.empty( + [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype + ) + speculate_get_target_logits( + target_logtis, + logits, + cu_batch_token_offset, + ori_cu_batch_token_offset, + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + share_inputs["accept_num"], + ) + raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata) + + logprobs_tensors = None + token_ids = share_inputs["accept_tokens"] + if num_logprobs is not None: + token_ids = paddle.concat( + [ + share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] + for i in range(share_inputs["accept_num"][:real_bsz].shape[0]) + ] + ) + logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=batch_token_num, + cu_batch_token_offset=share_inputs["cu_batch_token_offset"], + ) + + return sampler_output class MTPSampler(nn.Layer): @@ -485,6 +629,7 @@ def __init__(self, fd_config: FDConfig): self.forward = self.forward_cuda else: raise NotImplementedError + self.speculative_tokens_num = fd_config.speculative_config.num_speculative_tokens def pre_process(self, skip_idx_list: List[int] = []): """pre process before running""" @@ -499,6 +644,103 @@ def apply_logits_processor( """apply logits processor to sampler""" pass + def compute_logprobs( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + ) -> paddle.Tensor: + """compute logprobs""" + share_inputs = sampling_metadata.share_inputs + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + last_logits = logits + temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs + top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs + if temp_scaled_logprobs is not None: + real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz] + temperature = sampling_metadata.temperature[:real_bsz] + real_bsz_temp_scaled = ( + real_bsz_temp_scaled.astype("int32") + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) + .astype("bool") + ) + temperature = temperature.squeeze(1).repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) + temp_temperature = paddle.where( + real_bsz_temp_scaled, 
temperature, paddle.ones_like(temperature) + ).unsqueeze(1) + last_logits = last_logits / temp_temperature + + last_logprobs = F.log_softmax(last_logits, axis=-1) + top_p_logprob = None + top_p_token_mask = None + + if top_p_normalized_logprobs is not None and share_inputs is not None: + real_token_top_p = ( + sampling_metadata.top_p[:real_bsz] + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) + .unsqueeze(1) + ) + top_p_normalized_logprobs = ( + top_p_normalized_logprobs[:real_bsz] + .astype("int32") + .squeeze(1) + .repeat_interleave(share_inputs["batch_token_num"][:real_bsz]) + .astype("bool") + .unsqueeze(1) + ) + top_p_token_mask = paddle.logical_and(top_p_normalized_logprobs, real_token_top_p != 1.0) + + if top_p_token_mask.any(): + probs = F.softmax(last_logits, axis=-1) + probs = top_p_normalize_probs_paddle(probs, real_token_top_p) + top_p_logprob = paddle.log(probs) + if top_p_logprob is not None: + last_logprobs = paddle.where(top_p_token_mask, top_p_logprob, last_logprobs) + return last_logprobs + + def gather_logprobs( + self, + logprobs: paddle.Tensor, + num_logprobs: int, + token_ids: paddle.Tensor, + ) -> LogprobsTensors: + """ + Gather logprobs for topk and sampled/prompt token. + Args: + logprobs: (num tokens) x (vocab) tensor + num_logprobs: minimum number of logprobs to + retain per token + token_ids: prompt tokens (if prompt logprobs) + or sampled tokens (if sampled + logprobs); 1D token ID tensor + with (num tokens) elements + Must be int64. + Returns: + Top-k int indices tensor, (num tokens) x (num_logprobs + 1) + Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1) + Sampled token rank tensor, (num tokens) + """ + assert token_ids.dtype == paddle.int64 + token_ids = token_ids.unsqueeze(1) + logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) + # Get with the logprob of the prompt or sampled token. + token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) + + # Compute the ranks of the actual token. + token_ranks = (logprobs >= token_logprobs).sum(-1) + + if num_logprobs >= 1: + # Find the topK values. 
+ topk_logprobs, topk_indices = paddle.topk(logprobs, num_logprobs, axis=-1) + indices = paddle.concat([token_ids, topk_indices], axis=1) + top_logprobs = paddle.concat([token_logprobs, topk_logprobs], axis=1) + else: + indices = token_ids + top_logprobs = token_logprobs + + return LogprobsTensors(indices, top_logprobs, token_ranks) + def forward_cuda( self, logits: paddle.Tensor, @@ -507,6 +749,12 @@ def forward_cuda( share_inputs: List[paddle.Tensor], ) -> paddle.Tensor: """ """ + num_logprobs = sampling_metadata.max_num_logprobs + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + if num_logprobs is not None and share_inputs["substep"] == 0: + real_token_num = share_inputs["batch_token_num"][:real_bsz].sum() + raw_logprobs = self.compute_logprobs(share_inputs["draft_logits"][:real_token_num, :], sampling_metadata) + logits = apply_speculative_penalty_multi_scores( sampling_metadata.pre_token_ids, logits, @@ -528,4 +776,28 @@ def forward_cuda( _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list ) - return next_tokens + + token_ids = None + logprobs_tensors = None + if num_logprobs is not None and share_inputs["substep"] == 0: + token_ids = paddle.empty(real_token_num, dtype="int64") + speculate_insert_first_token( + token_ids, + share_inputs["accept_tokens"], + next_tokens, + share_inputs["cu_next_token_offset"], + share_inputs["cu_batch_token_offset"], + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + ) + + logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids) + + sampler_output = SamplerOutput( + sampled_token_ids=token_ids, + logprobs_tensors=logprobs_tensors, + token_num_per_batch=share_inputs["batch_token_num"][:real_bsz], + cu_batch_token_offset=share_inputs["cu_batch_token_offset"], + ) + + return next_tokens, sampler_output diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 184f05faba..d41b2d674b 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -64,6 +64,7 @@ speculate_get_padding_offset, speculate_get_seq_lens_output, speculate_save_output, + speculate_save_output_topk, speculate_set_value_by_flags_and_idx, speculate_step_paddle, speculate_step_system_cache, @@ -306,7 +307,10 @@ def post_process_normal( def post_process_specualate( - model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False + sampler_output: SamplerOutput, + model_output: ModelOutputData, + save_each_rank: bool = False, + skip_save_output: bool = False, ): """""" speculate_update( @@ -324,16 +328,29 @@ def post_process_specualate( ) if not skip_save_output: - speculate_save_output( - model_output.accept_tokens, - model_output.accept_num, - model_output.not_need_stop, - model_output.seq_lens_decoder, - model_output.prompt_lens, - model_output.mp_rank, - save_each_rank, - envs.ENABLE_V1_KVCACHE_SCHEDULER, - ) + if sampler_output.logprobs_tensors is None: + speculate_save_output( + model_output.accept_tokens, + model_output.accept_num, + model_output.not_need_stop, + model_output.seq_lens_decoder, + model_output.prompt_lens, + model_output.mp_rank, + save_each_rank, + envs.ENABLE_V1_KVCACHE_SCHEDULER, + ) + else: + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + 
sampler_output.logprobs_tensors.selected_token_ranks, + sampler_output.token_num_per_batch, + sampler_output.cu_batch_token_offset, + model_output.not_need_stop, + 3, # mtype + model_output.mp_rank, + ) # Update pre_ids through accept tokens @@ -360,7 +377,7 @@ def post_process( ) -> None: """Post-processing steps after completing a single token generation.""" if speculative_decoding: - post_process_specualate(model_output, save_each_rank, skip_save_output) + post_process_specualate(sampler_output, model_output, save_each_rank, skip_save_output) else: post_process_normal(sampler_output, model_output, share_inputs, block_size, save_each_rank, skip_save_output) @@ -529,6 +546,8 @@ def rebuild_padding( seq_lens_encoder: paddle.Tensor, output_padding_offset: Optional[paddle.Tensor] = None, max_input_length: Optional[int] = None, + first_token_out: Optional[paddle.Tensor] = None, + enable_logprob: Optional[bool] = False, ): """ Args: @@ -544,7 +563,9 @@ def rebuild_padding( seq_lens_decoder, seq_lens_encoder, output_padding_offset, + first_token_out, max_input_length, + enable_logprob, ) elif current_platform.is_dcu(): from fastdeploy.model_executor.ops.gpu import rebuild_padding diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index e0fdc41640..6fb7da8825 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -22,6 +22,7 @@ import weakref from collections import Counter from concurrent.futures import ThreadPoolExecutor +from typing import List import numpy as np @@ -60,11 +61,20 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.use_logprobs = self.cfg.model_config.enable_logprob if self.speculative_decoding: - self.output_tokens = paddle.full( - shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], - fill_value=2, - dtype="int64", - ) + if self.use_logprobs: + self.output_tokens = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64" + ) + self.output_scores = paddle.full( + shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32" + ) + self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64") + else: + self.output_tokens = paddle.full( + shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], + fill_value=2, + dtype="int64", + ) elif self.use_logprobs: self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64") self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32") @@ -100,6 +110,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn self.executor = ThreadPoolExecutor(max_workers=1) self.prefill_result_status = dict() self._finalizer = weakref.finalize(self, self._cleanup_resources) + self._batch_result_buffer = None def _cleanup_resources(self): """Cleaning up shared memory resources""" @@ -149,6 +160,7 @@ def process_sampling_results(self): get_output_ep, get_output_topk, speculate_get_output, + speculate_get_output_topk, ) rank_id = self.cfg.parallel_config.local_data_parallel_id @@ -156,16 +168,27 @@ def process_sampling_results(self): try: is_blocking = True if self.speculative_decoding: - if ( - self.cfg.parallel_config.enable_expert_parallel - and self.cfg.parallel_config.data_parallel_size > 1 - ): - speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + if self.use_logprobs: + 
speculate_get_output_topk( + self.output_tokens, + self.output_scores, + self.output_ranks, + K, + rank_id, + is_blocking, + ) + if self.output_tokens[0, 0] == -2: + continue else: - speculate_get_output(self.output_tokens, rank_id, is_blocking, False) - if self.output_tokens[0] == -2: - continue - + if ( + self.cfg.parallel_config.enable_expert_parallel + and self.cfg.parallel_config.data_parallel_size > 1 + ): + speculate_get_output(self.output_tokens, rank_id, is_blocking, True) + else: + speculate_get_output(self.output_tokens, rank_id, is_blocking, False) + if self.output_tokens[0] == -2: + continue else: if self.use_logprobs: get_output_topk( @@ -210,7 +233,7 @@ def process_metrics(): self.executor.submit(process_metrics) - def postprocess(self, batch_result): + def postprocess(self, batch_result: List[RequestOutput], mtype=3): """ single post-processing function @@ -218,7 +241,28 @@ def postprocess(self, batch_result): batch_result (list): batch results """ try: - self.cached_generated_tokens.put_results(batch_result) + if self.cfg.speculative_config.method and self.use_logprobs: + if mtype == 3: # target + finished_batch_result, unfinished_batch_result = [], [] + for r in batch_result: + (finished_batch_result if r.finished else unfinished_batch_result).append(r) + if finished_batch_result: + self.cached_generated_tokens.put_results(batch_result) + else: + self._batch_result_buffer = unfinished_batch_result + elif mtype == 4: # draft + target_batch_result = [] + draft_batch_result = batch_result + if self._batch_result_buffer is not None: + for target, decode in zip(self._batch_result_buffer, draft_batch_result): + target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs + target_batch_result.append(target) + self._batch_result_buffer = None + self.cached_generated_tokens.put_results(target_batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) + else: + self.cached_generated_tokens.put_results(batch_result) except Exception as e: llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}") @@ -299,9 +343,25 @@ def _process_batch_output(self): tokens = self.output_tokens.numpy() scores = None ranks = None + # target:3, draft:4 + mtype = 3 if self.cfg.speculative_config.method: - batch = self.output_tokens[1] - accept_num = tokens[2 : batch + 2] + if self.use_logprobs: + mtype = int(self.output_tokens[1, 0].item()) + batch = self.output_tokens[2, 0] + accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]] + tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape( + [batch, MAX_DRAFT_TOKENS, K + 1] + ) + scores = ( + self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)] + .numpy() + .reshape([batch, MAX_DRAFT_TOKENS, K + 1]) + ) + ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS]) + else: + batch = self.output_tokens[1] + accept_num = tokens[2 : batch + 2] self._record_speculative_decoding_mertics(accept_num) elif self.use_logprobs: batch = self.output_tokens[1, 0] @@ -329,19 +389,24 @@ def _process_batch_output(self): task_id = task.request_id if self.cfg.speculative_config.method: - token_ids = tokens[ - 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS : 2 - + SPECULATE_MAX_BSZ - + i * MAX_DRAFT_TOKENS - + accept_num[i] - ].tolist() - if len(token_ids) == 0 or token_ids[-1] <= 0: - if envs.ENABLE_V1_KVCACHE_SCHEDULER: - if task_id in self.resource_manager.to_be_rescheduled_request_id_set: - 
self.resource_manager.reschedule_preempt_task(task_id) - continue + if accept_num[i] == -3: + recovery_stop = True + if recovery_stop: + llm_logger.info(f"recovery stop signal found at task {task_id}") + token_ids = [RECOVERY_STOP_SIGNAL] + elif self.use_logprobs: + token_ids = tokens[i][:, 0].tolist()[: accept_num[i]] + else: + token_ids = tokens[ + 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS : 2 + + SPECULATE_MAX_BSZ + + i * MAX_DRAFT_TOKENS + + accept_num[i] + ].tolist() + if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0): + continue else: token_id = int(tokens[i, 0]) token_ids = [token_id] @@ -384,6 +449,7 @@ def _process_batch_output(self): self._record_metrics(task, current_time, token_ids) result = RequestOutput( request_id=task_id, + output_type=mtype, outputs=CompletionOutput( index=i, send_idx=self.tokens_counter[task_id], @@ -403,28 +469,53 @@ def _process_batch_output(self): if is_prefill and len(token_ids) > 1: result.outputs.draft_token_ids = copy.deepcopy(token_ids) - for token_id in token_ids: + for batch_token_index in range(len(token_ids)): + token_id = token_ids[batch_token_index] self.tokens_counter[task_id] += 1 if token_id != RECOVERY_STOP_SIGNAL: result.outputs.token_ids.append(token_id) task.output_token_ids.append(token_id) if self.use_logprobs: - result.outputs.logprob = float(scores[i, 0]) - # Construct top_logprobs - topk_token_ids = tokens[i, :].tolist() - topk_logprobs = scores[i, :].tolist() - sampled_rank = ranks[i].item() - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - if token_id in task.eos_token_ids or is_prefill or recovery_stop: + if self.cfg.speculative_config.method: + result.outputs.logprob = float(scores[i, batch_token_index, 0]) + topk_token_ids = tokens[i, batch_token_index, :].tolist() + topk_logprobs = scores[i, batch_token_index, :].tolist() + sampled_rank = ranks[i, batch_token_index].item() + else: + result.outputs.logprob = float(scores[i, 0]) + topk_token_ids = tokens[i, :].tolist() + topk_logprobs = scores[i, :].tolist() + sampled_rank = ranks[i].item() + + if mtype == 3: # top_logprobs + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) + elif mtype == 4: # draft_top_logprobs + if result.outputs.draft_top_logprobs is None: + result.outputs.draft_top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank]) + if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop): result.finished = True if recovery_stop: result.error_msg = "Recover is not supported, the result is incomplete!" llm_logger.info( - f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}." 
+ f"Request: {task_id} finished, number of " + f"generated tokens: {self.tokens_counter[task_id]}, token_id:{token_id},is_prefill:{is_prefill},recovery_stop:{recovery_stop}" ) llm_logger.info( f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}" @@ -436,10 +527,14 @@ def _process_batch_output(self): self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) break - if not is_prefill or self.cfg.scheduler_config.name == "splitwise": + if ( + not is_prefill + or self.cfg.scheduler_config.name == "splitwise" + or self.cfg.scheduler_config.name == "dp" + ): batch_result.append(result) - self.postprocess(batch_result) + self.postprocess(batch_result, mtype) def _record_metrics(self, task, current_time, token_ids): """Record all metrics for a task""" diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index fb7d326450..eec60756b7 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -44,6 +44,8 @@ mtp_save_first_token, mtp_step_paddle, share_external_data, + speculate_get_logits, + speculate_save_output_topk, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding @@ -72,6 +74,7 @@ def __init__( self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps + self.enable_logprob = self.model_config.enable_logprob # [mixed, prefill, decoder] self.role = "mixed" @@ -393,6 +396,22 @@ def _init_model_inputs(self): self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" ) self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + self.model_inputs["temp_scaled_logprobs"] = self.target_model_inputs["temp_scaled_logprobs"] + self.model_inputs["top_p_normalized_logprobs"] = self.target_model_inputs["top_p_normalized_logprobs"] + self.model_inputs["accept_num"] = self.target_model_inputs["accept_num"] + self.model_inputs["accept_tokens"] = self.target_model_inputs["accept_tokens"] + self.model_inputs["draft_logits"] = self.target_model_inputs["draft_logits"] + self.model_inputs["first_token_hidden_states"] = paddle.full( + [self.max_num_seqs, self.model_config.hidden_size], -1 + ) + self.model_inputs["batch_token_num"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32") + self.model_inputs["next_token_num"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32") + self.model_inputs["cu_batch_token_offset"] = paddle.full_like( + self.target_model_inputs["cu_batch_token_offset"], fill_value=0, dtype="int32" + ) + self.model_inputs["cu_next_token_offset"] = paddle.full( + shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32" + ) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): @@ -723,6 +742,10 @@ def _propose(self): min_dec_lens=self.model_inputs["min_dec_len"], bad_words_token_ids=self.model_inputs["bad_tokens"], eos_token_ids=self.model_inputs["eos_token_id"], + max_num_logprobs=20 if self.enable_logprob else None, + temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"], + share_inputs=self.model_inputs, ) if self.num_model_steps > 1: @@ -744,18 +767,48 @@ def _propose(self): self.model_inputs["seq_lens_encoder"], self.model_inputs["output_padding_offset"], 
self.parallel_config.max_model_len, + self.model_inputs["first_token_hidden_states"], + self.enable_logprob if substep == 0 else False, ) # 4. Compute logits, Sample logits = self.model.compute_logits(hidden_states) + if self.enable_logprob and substep == 0: + first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"]) + + speculate_get_logits( + self.model_inputs["draft_logits"], + self.model_inputs["next_token_num"], + self.model_inputs["batch_token_num"], + self.model_inputs["cu_next_token_offset"], + self.model_inputs["cu_batch_token_offset"], + logits, + first_token_logits, + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + ) - sampled_token_ids = self.sampler( + sampled_token_ids, sampler_output = self.sampler( logits, self.sampling_metadata, self.max_model_len, self.model_inputs, ) + if substep == 0 and sampler_output.logprobs_tensors is not None: + real_bsz = self.model_inputs["seq_lens_this_time"].shape[0] + speculate_save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + self.model_inputs["batch_token_num"][:real_bsz], + self.model_inputs["cu_batch_token_offset"][:real_bsz], + self.model_inputs["not_need_stop"], + 4, # mtype + self.local_rank, + ) + if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( sampled_token_ids, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 6e281fd5cc..5050c49943 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -812,6 +812,15 @@ def _init_share_inputs(self, max_num_seqs: int): dtype="int64", ) self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + # For MTP Logprob + self.share_inputs["draft_logits"] = paddle.full( + [max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size], + -1, + dtype="float32", + ) + self.share_inputs["cu_batch_token_offset"] = paddle.full( + shape=[max_num_seqs + 1], fill_value=0, dtype="int32" + ) if self.enable_mm: head_dim = self.model_config.head_dim @@ -1520,13 +1529,12 @@ class at the server level, which is too granular for ModelRunner. ) else: - self.sampler( + sampler_output = self.sampler( logits, self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs, ) - sampler_output = None if self.parallel_config.tensor_parallel_size > 1: paddle.distributed.broadcast( self.share_inputs["accept_tokens"], diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 2fa348634c..1128062814 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -106,6 +106,8 @@ class SamplerOutput: # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. sampled_token_ids: paddle.Tensor logprobs_tensors: Optional[LogprobsTensors] + token_num_per_batch: Optional[paddle.Tensor] = None + cu_batch_token_offset: Optional[paddle.Tensor] = None @dataclass
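
For reference when reading the new speculative `use_logprobs` branch of `_process_batch_output` above: the top-k results arrive in three flat buffers whose layout is implied by the slicing code (a small header of ready-flag / mtype / batch / accept_num, followed by a [batch, MAX_DRAFT_TOKENS, K + 1] grid). The sketch below decodes that layout in plain NumPy; the helper name `decode_speculative_topk` and the small MAX_BSZ / MAX_DRAFT_TOKENS / K values are illustrative assumptions, not part of FastDeploy, and the real buffers are Paddle tensors rather than NumPy arrays.

# Minimal NumPy sketch of the buffer layout read by the token processor.
# Assumed placeholder constants; FastDeploy takes the real ones from its config/env.
import numpy as np

MAX_BSZ, MAX_DRAFT_TOKENS, K = 4, 3, 2  # placeholder sizes for illustration


def decode_speculative_topk(output_tokens, output_scores, output_ranks):
    """Decode the flat buffers filled by speculate_get_output_topk.

    output_tokens: int64, shape [MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1]
    output_scores: float32, shape [MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1]
    output_ranks:  int64, shape [MAX_BSZ * MAX_DRAFT_TOKENS]
    """
    if output_tokens[0, 0] == -2:  # nothing ready yet; the reader loop just continues
        return None
    mtype = int(output_tokens[1, 0])  # 3 = target-model pass, 4 = draft-model pass
    batch = int(output_tokens[2, 0])  # number of live requests in this batch
    accept_num = [int(n) for n in output_tokens[3 : batch + 3, 0]]  # -3 marks a recovery stop
    body = output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1), 0]
    token_ids = body.reshape(batch, MAX_DRAFT_TOKENS, K + 1)
    scores = output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1), 0].reshape(batch, MAX_DRAFT_TOKENS, K + 1)
    ranks = output_ranks[: batch * MAX_DRAFT_TOKENS].reshape(batch, MAX_DRAFT_TOKENS)
    # Column 0 of the last axis is the sampled/accepted token; the remaining K
    # columns are the top-K alternatives, matching gather_logprobs' concat order.
    return mtype, accept_num, token_ids, scores, ranks

For request i, only the first accept_num[i] rows along the draft axis carry real tokens, which is why the processing loop above slices tokens[i][:, 0].tolist()[: accept_num[i]] before appending to the request output.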