diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 169be85246..4fc43e34fa 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -14,7 +14,9 @@ #include "cute/tensor.hpp" #include "helper.h" #include "paddle/extension.h" +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU #include "paddle/phi/core/memory/memcpy.h" +#endif #include "utils.cuh" template diff --git a/custom_ops/gpu_ops/append_attn/mem_util.cuh b/custom_ops/gpu_ops/append_attn/mem_util.cuh index 89b65992d6..fb735be7ae 100644 --- a/custom_ops/gpu_ops/append_attn/mem_util.cuh +++ b/custom_ops/gpu_ops/append_attn/mem_util.cuh @@ -15,6 +15,7 @@ #include #include +#include enum class SharedMemFillMode { kFillZero, kNoFill }; @@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R, } __device__ __forceinline__ void commit_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + {} +#else asm volatile("cp.async.commit_group;\n" ::); +#endif } template __device__ __forceinline__ void wait_group() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + cooperative_groups::wait(cooperative_groups::this_thread_block()); +#else asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif } template __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + } +#else if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { asm volatile( "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"( @@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { "n"(16), "r"(16)); } +#endif } template @@ -76,6 +95,28 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, bool predicate) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + } else { + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + } + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16); + } + } + } +#else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 
16 : 0; if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { @@ -115,6 +156,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, "n"(16)); } } +#endif } template @@ -123,6 +165,17 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr, bool predicate) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 8 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8); + } + } +#else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 8 : 0; asm volatile( @@ -141,6 +194,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr, "l"(gmem_ptr), "n"(8)); } +#endif } template @@ -149,6 +203,17 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr, bool predicate) { uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 4 : 0; + memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4); + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes); + } else { + if (predicate) { + memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4); + } + } +#else if constexpr (fill_mode == SharedMemFillMode::kFillZero) { int src_in_bytes = predicate ? 4 : 0; asm volatile( @@ -167,6 +232,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr, "l"(gmem_ptr), "n"(4)); } +#endif } template diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 6f6554f03f..5be2dd2689 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -592,10 +592,13 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { #endif inline int GetSMVersion() { +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + return 80; +#else static int sm_version = phi::backends::gpu::GetGPUComputeCapability( phi::backends::gpu::GetCurrentDeviceId()); return sm_version; - +#endif } inline bool GetMlaUseTensorcore() { diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index 392dbfe3b1..e83d3e50ba 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -18,6 +18,7 @@ #include #include #include "helper.h" +#include namespace cg = cooperative_groups; @@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel( if (i < topk) { s_topk_value[i] = value; } - topk_sum += reduce(tile, cuda_cast(value), cg::plus()); + topk_sum += cg::reduce(tile, cuda_cast(value), cg::plus()); } } @@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores, cudaStream_t const stream) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; + +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + topk_with_k2_kernel<<>>( + group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group); +#else auto* kernel_instance1 = &topk_with_k2_kernel; cudaLaunchConfig_t config; config.gridDim = topk_with_k2_num_blocks; @@ -671,6 +677,7 @@ void invokeNoAuxTc(T* scores, config.attrs = attrs; cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, num_tokens, num_cases, n_group, 
num_experts / n_group); +#endif int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; @@ -678,6 +685,12 @@ void invokeNoAuxTc(T* scores, warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, topk); +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + group_idx_and_topk_idx_kernel<<>>( + scores, group_scores, topk_values, topk_indices, scores_with_bias, + num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group, + renormalize, routed_scaling_factor); +#else auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; config.gridDim = topk_with_k_group_num_blocks; config.blockDim = BLOCK_SIZE; @@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores, topk_values, topk_indices, scores_with_bias, num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group, renormalize, routed_scaling_factor); +#endif } #define INSTANTIATE_NOAUX_TC(T, IdxT) \ diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index d1d06e9c2f..b2df172d8a 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -600,9 +600,16 @@ def find_end_files(directory, end_str): "gpu_ops/read_data_ipc.cu", "gpu_ops/dequant_int8.cu", "gpu_ops/share_external_data.cu", + "gpu_ops/recover_decode_task.cu", + "gpu_ops/noaux_tc.cu", + "gpu_ops/fused_rotary_position_encoding.cu", + "gpu_ops/text_image_gather_scatter.cu", + "gpu_ops/text_image_index_out.cu", + "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", + "gpu_ops/append_attn/mla_cache_kernel.cu", + "gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu", "gpu_ops/moe/tritonmoe_preprocess.cu", "gpu_ops/moe/moe_topk_select.cu", - "gpu_ops/recover_decode_task.cu", "metax_ops/moe_dispatch.cu", "metax_ops/moe_ffn.cu", "metax_ops/moe_reduce.cu", diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 11f84fbd0b..49c6be4c73 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1160,7 +1160,7 @@ def __init__(self, args): self.kv_cache_ratio = 1.0 else: self.kv_cache_ratio = 0.75 - self.enc_dec_block_num = 0 if current_platform.is_maca() else envs.FD_ENC_DEC_BLOCK_NUM + self.enc_dec_block_num = envs.FD_ENC_DEC_BLOCK_NUM self.prealloc_dec_block_slot_num_threshold = 12 self.cache_dtype = "bfloat16" self.model_cfg = None diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 52f0be8a71..6fcfe2f62d 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -1006,7 +1006,8 @@ def _exit_sub_services(self): exit sub services """ self.running = False - self.engine_worker_queue_server.cleanup() + if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None: + self.engine_worker_queue_server.cleanup() self.exist_task_signal.clear() self.exist_swapped_task_signal.clear() self.worker_healthy_live_signal.clear() diff --git a/fastdeploy/model_executor/layers/backends/metax/__init__.py b/fastdeploy/model_executor/layers/backends/metax/__init__.py index 568c7d9972..1cc5b8260a 100644 --- a/fastdeploy/model_executor/layers/backends/metax/__init__.py +++ b/fastdeploy/model_executor/layers/backends/metax/__init__.py @@ -13,11 +13,13 @@ # limitations under the License. 
from .attention.flash_attn_backend import FlashAttentionBackend +from .attention.mla_attn_metax_backend import MetaxMLAAttentionBackend from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod __all__ = [ "FlashAttentionBackend", + "MetaxMLAAttentionBackend", "MetaxTritonWeightOnlyMoEMethod", "MetaxCutlassWeightOnlyMoEMethod", ] diff --git a/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py new file mode 100644 index 0000000000..ab34b9b6ee --- /dev/null +++ b/fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py @@ -0,0 +1,444 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +import math +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Tuple + +import paddle + +from fastdeploy.model_executor.ops.gpu import ( + decode_mla_write_cache, + get_block_shape_and_split_kv_block, + prefill_mla_write_cache, +) + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) +from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id +from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_interface import ( + flash_attn_unpadded_func, +) + + +def yarn_get_mscale(scale=1, mscale=1): + """ """ + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +@dataclass +class MLAAttentionMetadata(AttentionMetadata): + """ + MLAAttentionMetadata for Multi-Layer Attention + """ + + _dtype: paddle.dtype = paddle.bfloat16 + encoder_max_partition_size: int = 32768 + max_partition_size: int = 32768 + block_tables: Optional[paddle.Tensor] = None + rotary_embs: Optional[paddle.Tensor] = None + attn_mask: Optional[paddle.Tensor] = None + _fuse_kernel_compute_dtype: str = "bf16" + + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list) + + max_enc_len_this_time: Optional[paddle.Tensor] = None + max_dec_len_this_time: Optional[paddle.Tensor] = None + max_kv_len_this_time: Optional[paddle.Tensor] = None + + +class MetaxMLAAttentionBackend(AttentionBackend): + """ + MLA Attention Backend implementation. 
+ """ + + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: MLAAttentionMetadata + flash_attn_func: callable = None + + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ) -> None: + """ + MLAAttentionBackend __init__ + """ + super().__init__() + self.attention_metadata: MLAAttentionMetadata = None + + # 基础配置 + self.block_size: int = fd_config.cache_config.block_size + self.max_seq_len: int = fd_config.model_config.max_model_len + self.rope_theta: float = ( + 10000.0 if fd_config.model_config.rope_theta is None else fd_config.model_config.rope_theta + ) + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) + self.causal: bool = getattr(fd_config.model_config, "causal", True) + self.speculative_method: str = fd_config.speculative_config.method + self.use_speculate: bool = self.speculative_method is not None + self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens + self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + + self.kv_num_heads: int = kv_num_heads + self.num_heads: int = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads + self.head_dim: int = fd_config.model_config.head_dim + self.num_layers: int = fd_config.model_config.num_hidden_layers + self.encoder_block_shape_q: int = encoder_block_shape_q + self.decoder_block_shape_q: int = decoder_block_shape_q + + # For Multi Head Latent Attention + self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank + self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim + self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim + self.attn_softmax_scale: float = self.qk_head_dim**-0.5 + if fd_config.model_config.rope_scaling: + mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False) # 1.0 + scaling_factor = fd_config.model_config.rope_scaling["factor"] # 40 + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.attn_softmax_scale = self.attn_softmax_scale * mscale * mscale + + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + + self.start_layer_index: int = fd_config.model_config.start_layer_index + self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) + + self.rank, self.device_id = init_rank_and_device_id(fd_config) + + self.flash_attn_func = flash_attn_unpadded_func + self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale} + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attention metadata hence all layers in the forward pass can reuse it.""" + metadata = MLAAttentionMetadata() + metadata.max_partition_size = 32768 + metadata.encoder_max_partition_size = self.max_seq_len + metadata._dtype = paddle.get_default_dtype() + if metadata._dtype == "bfloat16": + metadata._fuse_kernel_compute_dtype = "bf16" + elif metadata._dtype == "float16": + metadata._fuse_kernel_compute_dtype = "fp16" + elif metadata._dtype == "float32": + metadata._fuse_kernel_compute_dtype = "fp32" + + metadata.block_tables = forward_meta.block_tables + metadata.rotary_embs = forward_meta.rotary_embs + metadata.attn_mask = forward_meta.attn_mask + metadata.pre_caches_length = forward_meta.pre_caches_length + + 
get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decoder_batch_ids, # decoder_batch_ids_per_ctax + forward_meta.decoder_tile_ids_per_batch, # decoder_chunk_ids_per_ctax_each_batch + forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, + forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, + self.block_size, + self.speculate_max_draft_token_num + 1, + ) + + # MLA + metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1] + metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2] + metadata.max_kv_len_this_time = forward_meta.max_len_tensor_cpu[8] + + # pd_disaggregation + metadata.kv_signal_data_list = [None] * self.num_layers + + self.attention_metadata: AttentionMetadata = metadata + + def get_attntion_meta(self) -> AttentionMetadata: + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + kv_cache_quant_type: str = None, + ) -> Tuple[int, int, int, int]: + """ + Calculate kv cache shape for MLA + """ + return ( + max_num_blocks, + 1, + self.block_size, + self.kv_lora_rank + self.qk_rope_head_dim, + ) + + def forward_extend( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + Prefill阶段的前向传播 + """ + metadata = self.attention_metadata + + latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + + # 写入缓存 + prefill_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + metadata.block_tables, + "none", + getattr(forward_meta, "max_input_length", -1), + ) + + # Flash注意力计算 + fmha_out = self.flash_attn_func( + q, + k, + v, + forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_k, + metadata.max_enc_len_this_time, + metadata.max_enc_len_this_time, + causal=self.causal, + **self.flash_attn_kwargs, + )[0] + + return fmha_out + + def _run_single_flash_mla(self, query, latent_cache, block_tables, seq_lens, draft_token_num): + from flash_mla_paddle import flash_mla_with_kvcache, get_mla_metadata + + qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim + v_head_dim = self.kv_lora_rank + q_head_num = self.num_heads + kv_head_num = latent_cache.shape[2] + + query = query.reshape([-1, draft_token_num, q_head_num, qk_head_dim]) + tile_scheduler_metadata, num_splits = get_mla_metadata( + seq_lens, draft_token_num * q_head_num // kv_head_num, kv_head_num + ) + + out, _ = flash_mla_with_kvcache( + query, + latent_cache, + block_tables, + seq_lens, + v_head_dim, + tile_scheduler_metadata, + num_splits, + softmax_scale=self.attn_softmax_scale, + causal=True, + ) + + return out.reshape([-1, q_head_num, v_head_dim]) + + def compute_flash_mla(self, query, latent_cache, forward_meta): + block_tables = self.attention_metadata.block_tables + seq_lens_decoder = forward_meta.seq_lens_decoder + seq_lens_this_time = forward_meta.seq_lens_this_time + 
assert block_tables is not None and seq_lens_decoder is not None and seq_lens_this_time is not None + assert block_tables.shape[0] == seq_lens_decoder.shape[0] + + query = query.reshape([-1, self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim]) + latent_cache = latent_cache.transpose([0, 2, 1, 3]) + + seq_lens_decoder = seq_lens_decoder.squeeze(-1) + seq_lens_this_time = seq_lens_this_time.squeeze(-1) + non_zero_index = paddle.nonzero(seq_lens_this_time).flatten() + seq_lens_decoder = seq_lens_decoder[non_zero_index] + seq_lens_this_time = seq_lens_this_time[non_zero_index] + block_tables = block_tables[non_zero_index] + + max_seq_lens_this_time = seq_lens_this_time.max().item() + min_seq_lens_this_time = seq_lens_this_time.min().item() + + if max_seq_lens_this_time == min_seq_lens_this_time: + return self._run_single_flash_mla( + query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_seq_lens_this_time + ) + else: + max_draft_token_num = self.speculate_max_draft_token_num + 1 + seq_lens_this_time_cpu = seq_lens_this_time.cpu() + bsz = seq_lens_this_time_cpu.shape[0] + qk_head_dim = self.kv_lora_rank + self.qk_rope_head_dim + batched_query = paddle.zeros( + [bsz * max_draft_token_num, self.num_heads, qk_head_dim], dtype=query.dtype + ).to(query.place) + full_token_index = paddle.arange(bsz * max_draft_token_num, dtype="int32").reshape( + [bsz, max_draft_token_num] + ) + token_mapping_index = [] + for group_id in range(bsz): + seq_len = seq_lens_this_time_cpu[group_id] + token_mapping_index.append(full_token_index[group_id, :seq_len]) + token_mapping_index = paddle.concat(token_mapping_index) + assert token_mapping_index.shape[0] == query.shape[0] + batched_query[token_mapping_index] = query + seq_lens_this_time = paddle.full_like(seq_lens_this_time, fill_value=max_draft_token_num) + out = self._run_single_flash_mla( + batched_query, latent_cache, block_tables, seq_lens_decoder + seq_lens_this_time, max_draft_token_num + ) + return out[token_mapping_index] + + def forward_decode( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + Decode阶段的前向传播 + """ + metadata = self.attention_metadata + + latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + + # 获取推测解码参数 + speculate_decoder = self.speculative_method is not None + + # 写入缓存 + decode_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_encoder, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + metadata.block_tables, + "none", + self.max_seq_len, + speculate_decoder, + ) + + # 多头潜在注意力计算 + fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta) + + return fmha_out + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + Mixed模式的前向传播 + """ + metadata = self.attention_metadata + speculate_decoder = self.speculative_method is not None + + latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + + if k is not None: + prefill_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.batch_id_per_token, + 
forward_meta.cu_seqlens_q, + metadata.block_tables, + "none", + self.max_seq_len, + ) + + # FA + fmha_out = self.flash_attn_func( + q, + k, + v, + forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_k, + metadata.max_enc_len_this_time, + metadata.max_enc_len_this_time, + causal=self.causal, + **self.flash_attn_kwargs, + )[0] + + return fmha_out + + # Decode + if k is None: + decode_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_encoder, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + metadata.block_tables, + "none", + self.max_seq_len, + speculate_decoder, + ) + + # 多头潜在注意力计算 + fmha_out = self.compute_flash_mla(q, latent_cache, forward_meta) + + return fmha_out diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py index 0bd623998a..db472380ff 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py @@ -19,8 +19,10 @@ import fastdeploy from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce +from fastdeploy.model_executor.layers.moe.moe import get_moe_scores from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess +from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs from fastdeploy.utils import ceil_div from .triton_moe_kernels import fused_moe_kernel_paddle @@ -65,43 +67,74 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): layer.moe_intermediate_size, layer.hidden_size, ] - setattr( - layer, - up_gate_proj_weight_name, - layer.create_parameter( + # TODO(bukejiyu): remove v1 loader check when v0 loader is removed + if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1": + layer.up_gate_proj_weight = layer.create_parameter( shape=self.up_gate_proj_weight_shape, - dtype=self.weight_dtype, + dtype=layer.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - down_proj_weight_name, - layer.create_parameter( + ) + + layer.down_proj_weight = layer.create_parameter( shape=self.down_proj_weight_shape, - dtype=self.weight_dtype, + dtype=layer.weight_dtype, default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - # weight_scale - setattr( - layer, - self.added_scale_attrs[0], - layer.create_parameter( - shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], - dtype=self.default_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) - setattr( - layer, - self.added_scale_attrs[1], - layer.create_parameter( - shape=[layer.num_local_experts, layer.hidden_size], - dtype=self.default_dtype, - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) + ) + extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch" + + set_weight_attrs( + layer.up_gate_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True), + }, + ) + set_weight_attrs( + layer.down_proj_weight, + { + **extra_weight_attrs, + "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False), + }, + ) + else: + setattr( + layer, + up_gate_proj_weight_name, + 
layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + self.added_scale_attrs[0], + layer.create_parameter( + shape=[layer.num_local_experts, layer.moe_intermediate_size * 2], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + self.added_scale_attrs[1], + layer.create_parameter( + shape=[layer.num_local_experts, layer.hidden_size], + dtype=self.default_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # support cache feature in future @paddle.no_grad() def process_loaded_weights(self, layer: nn.Layer, state_dict): @@ -114,6 +147,8 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): algo = layer.quant_method.quant_config.name() + assert algo == "wint8" + assert up_gate_proj_weights[0].shape == [ layer.hidden_size, layer.moe_intermediate_size * 2, @@ -143,6 +178,63 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): getattr(layer, weight_name).set_value(quanted_weight) getattr(layer, scale_name).set_value(quanted_weight_scale) + @paddle.no_grad() + def process_weights_after_loading(self, layer): + """ """ + if not self.quant_config.is_checkpoint_bf16: + return + + algo = layer.quant_method.quant_config.name() + assert algo == "wint8" + max_bound = 127 + weight_id_map = {"gate_up": 0, "down": 1} + if ( + hasattr(layer.up_gate_proj_weight, "tensor_track") + and layer.up_gate_proj_weight.tensor_track is not None + and layer.up_gate_proj_weight.tensor_track.is_fully_copied() + ): + weight_type = "gate_up" + layer.up_gate_proj_weight.tensor_track = None + else: + weight_type = "down" + layer.down_proj_weight.tensor_track = None + + # weight + weight_name = self.added_weight_attrs[weight_id_map[weight_type]] + # scale + scale_name = self.added_scale_attrs[weight_id_map[weight_type]] + + weight_tensor = getattr(layer, weight_name) + quanted_weight_scale = weight_tensor.abs().max(axis=1) + quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound + quanted_weight = paddle.round(quanted_weight).astype("int8") + quanted_weight_scale = quanted_weight_scale / max_bound + + getattr(layer, weight_name).value().get_tensor()._clear() + + # create weight + setattr( + layer, + weight_name, + layer.create_parameter( + shape=weight_tensor.shape, + dtype=quanted_weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # create scale + setattr( + layer, + scale_name, + layer.create_parameter( + shape=quanted_weight_scale.shape, + dtype=quanted_weight_scale.dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + getattr(layer, weight_name).copy_(quanted_weight, False) + getattr(layer, scale_name).copy_(quanted_weight_scale, False) + @paddle.no_grad() def apply( self, @@ -157,38 +249,38 @@ def apply( token_num = x.shape[0] top_k = layer.top_k num_local_experts = layer.num_local_experts - top_k = layer.top_k moe_intermediate_size = layer.moe_intermediate_size hidden_size = layer.hidden_size - topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( - gate_out, - layer.gate_correction_bias, - layer.top_k, - True, # apply_norm_weight - False, - ) - + if layer.topk_method == "noaux_tc": 
+ gate_out, topk_weights, topk_ids = get_moe_scores( + gate_out, + layer.n_group, + layer.topk_group, + layer.top_k, + layer.routed_scaling_factor, + layer.gate_correction_bias, + getattr(layer, "renormalize", True), + ) + else: + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) up_gate_proj_out = paddle.empty( [token_num * top_k, moe_intermediate_size * 2], dtype=x.dtype, ) - if self.quant_config is not None: - config = { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - } - else: - config = { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - } - + config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + } sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess( topk_ids, num_local_experts, config["BLOCK_SIZE_M"] ) @@ -237,6 +329,7 @@ def apply( compute_type_enum=1, use_fp8_w8a8=False, use_int8_w8a16=True, + per_channel_quant=False, even_Ks=hidden_size % config["BLOCK_SIZE_K"] == 0, ) @@ -289,11 +382,12 @@ def apply( compute_type_enum=1, use_fp8_w8a8=False, use_int8_w8a16=True, + per_channel_quant=False, even_Ks=moe_intermediate_size % config["BLOCK_SIZE_K"] == 0, ) down_proj_out.reshape_([token_num, top_k, hidden_size]) out = down_proj_out.sum(axis=1) - if layer.tp_size > 1: - tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group) + if layer.reduce_results and layer.tp_size > 1: + out = tensor_model_parallel_all_reduce(out, layer.fd_config.parallel_config.tp_group) return out diff --git a/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py b/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py index a359330c55..f63641a664 100644 --- a/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py +++ b/fastdeploy/model_executor/layers/backends/metax/moe/triton_moe_kernels.py @@ -16,7 +16,7 @@ import triton.language as tl -@triton.jit +@triton.jit() def fused_moe_kernel_paddle( a_ptr, b_ptr, @@ -30,20 +30,20 @@ def fused_moe_kernel_paddle( # Matrix dimensions max_possible_num_post_padded, num_valid_tokens, - N, - K, - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, + N: tl.constexpr, + K: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_be: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_cm: tl.constexpr, + stride_cn: tl.constexpr, + stride_asm: tl.constexpr, + stride_ask: tl.constexpr, + stride_bse: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, # Block size for block-wise fp8 quantization group_n: tl.constexpr, group_k: tl.constexpr, @@ -57,6 +57,7 @@ def fused_moe_kernel_paddle( compute_type_enum: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + per_channel_quant: tl.constexpr, even_Ks: tl.constexpr, ): """ @@ -119,6 +120,13 @@ def fused_moe_kernel_paddle( a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + # channel-wise + elif per_channel_quant: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + # Load per-token scale for activations + 
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None] else: # (Zkk): every expert has one activation scale and weight scale. a_scale = tl.load(a_scale_ptr + off_experts) diff --git a/fastdeploy/model_executor/layers/rotary_embedding.py b/fastdeploy/model_executor/layers/rotary_embedding.py index 387601d2a2..2d7e59af6d 100644 --- a/fastdeploy/model_executor/layers/rotary_embedding.py +++ b/fastdeploy/model_executor/layers/rotary_embedding.py @@ -23,7 +23,7 @@ from fastdeploy.config import ModelConfig from fastdeploy.platforms import current_platform -if current_platform.is_cuda(): +if current_platform.is_cuda() or current_platform.is_maca(): from fastdeploy.model_executor.ops.gpu import fused_rotary_position_encoding from .utils import CpuGuard diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py index 4ec3f13912..dc6bedc918 100644 --- a/fastdeploy/model_executor/model_loader/default_loader.py +++ b/fastdeploy/model_executor/model_loader/default_loader.py @@ -43,12 +43,12 @@ def download_model(self, model_config: ModelConfig) -> None: def clean_memory_fragments(self, state_dict: dict) -> None: """clean_memory_fragments""" - if current_platform.is_cuda(): + if current_platform.is_cuda() or current_platform.is_maca(): if state_dict: for k, v in state_dict.items(): if isinstance(v, paddle.Tensor): v.value().get_tensor()._clear() - paddle.device.cuda.empty_cache() + paddle.device.empty_cache() paddle.device.synchronize() @measure_time() diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py index 09b85cbe7e..83f13382f8 100644 --- a/fastdeploy/model_executor/model_loader/default_loader_v1.py +++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py @@ -43,8 +43,8 @@ def download_model(self, model_config: ModelConfig) -> None: def clean_memory_fragments(self) -> None: """clean_memory_fragments""" - if current_platform.is_cuda(): - paddle.device.cuda.empty_cache() + if current_platform.is_cuda() or current_platform.is_maca(): + paddle.device.empty_cache() paddle.device.synchronize() @save_model() diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 5ff2bc4cf3..556cf85e03 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -55,7 +55,7 @@ ) from fastdeploy.platforms import current_platform -if current_platform.is_cuda(): +if current_platform.is_cuda() or current_platform.is_maca(): from fastdeploy.model_executor.ops.gpu import ( get_position_ids_and_mask_encoder_batch, ) diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py b/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py index e5fbb3be3f..7f3e1b13f1 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/image_op.py @@ -16,7 +16,7 @@ from fastdeploy.platforms import current_platform -if current_platform.is_cuda(): +if current_platform.is_cuda() or current_platform.is_maca(): from fastdeploy.model_executor.ops.gpu import ( text_image_gather_scatter, text_image_index_out, @@ -32,6 +32,6 @@ text_image_index_out, ) else: - raise ImportError("Unsupported platform, only support CUDA and XPU") + raise ImportError("Unsupported platform, only support CUDA, MACA and XPU") __all__ = 
["text_image_gather_scatter", "text_image_index_out"] diff --git a/fastdeploy/platforms/maca.py b/fastdeploy/platforms/maca.py index 250cebf6e1..384886294d 100644 --- a/fastdeploy/platforms/maca.py +++ b/fastdeploy/platforms/maca.py @@ -60,6 +60,9 @@ def get_attention_backend_cls(cls, selected_backend: _Backend): elif selected_backend == _Backend.APPEND_ATTN: logger.info("Using FLASH ATTN backend to instead of attend attention.") return "fastdeploy.model_executor.layers.backends.metax.attention.flash_attn_backend.FlashAttentionBackend" + elif selected_backend == _Backend.MLA_ATTN: + logger.info("Using MLA ATTN backend.") + return "fastdeploy.model_executor.layers.backends.metax.attention.mla_attn_metax_backend.MetaxMLAAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" diff --git a/fastdeploy/worker/metax_model_runner.py b/fastdeploy/worker/metax_model_runner.py index f6a2c0b15e..5382749b55 100644 --- a/fastdeploy/worker/metax_model_runner.py +++ b/fastdeploy/worker/metax_model_runner.py @@ -15,40 +15,50 @@ """ import os +import queue import time -from typing import List, Optional +from threading import Thread +from typing import List, Optional, cast import numpy as np import paddle +import zmq from paddle import nn from paddleformers.utils.log import logger from fastdeploy import envs from fastdeploy.config import FDConfig +from fastdeploy.engine.pooling_params import PoolingParams from fastdeploy.engine.request import Request, RequestType +from fastdeploy.engine.tasks import PoolingTask from fastdeploy.input.ernie4_5_vl_processor import DataProcessor +from fastdeploy.inter_communicator import IPCSignal, ZmqIpcClient from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, ) -from fastdeploy.model_executor.guided_decoding import get_guided_backend -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( +from fastdeploy.model_executor.guided_decoding import ( LogitsProcessorBase, + get_guided_backend, ) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) +from fastdeploy.model_executor.layers.pool.metadata import PoolingMetadata from fastdeploy.model_executor.layers.rotary_embedding import get_rope, get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler from fastdeploy.model_executor.model_loader import get_model_loader from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp +from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling from fastdeploy.model_executor.ops.gpu import ( recover_decode_task, + set_data_ipc, set_value_by_flags_and_idx, share_external_data, + speculate_schedule_cache, ) from fastdeploy.model_executor.pre_and_post_process import ( post_process, @@ -56,6 +66,7 @@ rebuild_padding, step_cuda, ) +from fastdeploy.output.pooler import PoolerOutput from fastdeploy.spec_decode import MTPProposer, NgramProposer from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput @@ -79,14 +90,12 @@ def __init__( self.speculative_decoding = self.speculative_method is not None self.enable_logprob = fd_config.model_config.enable_logprob 
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop - - self.guided_backend = None - if self.fd_config.structured_outputs_config.guided_decoding_backend != "off": - self.guided_backend = get_guided_backend(fd_config=self.fd_config) + self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" # VL model config: if self.enable_mm: - self._init_image_preprocess() + if "ernie" in self.fd_config.model_config.model_type: + self._init_image_preprocess() self.amp_black = [ "reduce_sum", @@ -111,14 +120,19 @@ def __init__( else: self.sampler = SpeculativeSampler(fd_config) + self.guided_backend = None + if self.fd_config.structured_outputs_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser()) + # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] - # Cuda Graph - self.graph_opt_level = self.graph_opt_config.graph_opt_level + # CUDA Graph self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes + self.cudagraph_only_prefill = self.graph_opt_config.cudagraph_only_prefill # Initialize share inputs self._init_share_inputs(self.scheduler_config.max_num_seqs) @@ -134,24 +148,86 @@ def __init__( # In the future, we will expand it as a list. self.attn_backends: list[AttentionBackend] = [] # self.attn_metadatas: list[AttentionMetadata] = [] - self.initialize_attn_backend() + self._initialize_attn_backend() # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None # Postprocess Env params - os.environ["INFERENCE_MSG_QUEUE_ID"] = str( - self.local_rank + int(self.parallel_config.engine_worker_queue_port) - ) + os.environ["INFERENCE_MSG_QUEUE_ID"] = str(self.parallel_config.engine_worker_queue_port) + logger.info(f"queue id is {str(self.parallel_config.engine_worker_queue_port)}") + + self.zmq_client = None + self.async_output_queue = None + if envs.FD_USE_GET_SAVE_OUTPUT_V1: + logger.info(f"zmq client get_save_output_rank{local_rank}") + self.zmq_client = ZmqIpcClient(name=f"get_save_output_rank{local_rank}", mode=zmq.PUSH) + self.zmq_client.connect() + self.zmq_client.socket.SNDTIMEO = 3000 + self.async_output_queue: queue.Queue = queue.Queue() + self.async_output_copy_thread = Thread( + target=self._async_output_busy_loop, + daemon=True, + name="WorkerAsyncOutputCopy", + ) + self.async_output_copy_thread.start() + + def _async_output_busy_loop(self): + """Entrypoint for the thread which handles outputs asynchronously.""" + while True: + try: + output = self.async_output_queue.get() + self.zmq_client.send_pyobj(output) + except Exception as e: + logger.exception("Exception in async output loop: %s", e) def exist_prefill(self): """ check whether prefill stage exist """ - if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0: - return 1 - else: - return 0 + return int(paddle.max(self.share_inputs["seq_lens_encoder"])) > 0 + + def exist_decode(self): + """ + check whether decode stage exist + """ + return int(paddle.max(self.share_inputs["seq_lens_decoder"])) > 0 + + def only_prefill(self): + """ + check whether prefill only + """ + if_only_prefill = True + decode_exists = None + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + 
only_prefill_batch_list = [] + decode_exists = self.exist_decode() + paddle.distributed.all_gather_object(only_prefill_batch_list, not decode_exists) + if_only_prefill = all(only_prefill_batch_list) + + if_only_prefill = if_only_prefill and not (decode_exists if decode_exists is not None else self.exist_decode()) + + return if_only_prefill + + def only_decode(self): + """ + check whether decode only + """ + # Update Batch type for cuda graph for if_only_decode + if_only_decode = True + prefill_exists = None + # mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + if_only_decode = all(only_decode_batch_list) + + if_only_decode = if_only_decode and not ( + prefill_exists if prefill_exists is not None else self.exist_prefill() + ) + + return if_only_decode def _init_speculative_proposer(self): """ @@ -188,7 +264,13 @@ def _init_logits_processor(self, request): elif request.structural_tag is not None: schemata_key = ("structural_tag", request.structural_tag) - return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key + return ( + self.guided_backend.get_logits_processor( + schemata_key=schemata_key, + enable_thinking=True, + ), + schemata_key, + ) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None): """ @@ -225,7 +307,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = dtype=paddle.int64, ) vision_inputs["images"] = paddle.to_tensor( - inputs["images"][request.image_start : request.image_end], dtype="uint8" + inputs["images"][request.image_start : request.image_end], + dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", ) vision_inputs["grid_thw"] = paddle.to_tensor( inputs["grid_thw"][request.num_image_start : request.num_image_end], dtype="int64" @@ -246,7 +329,20 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = position_ids, request.get("max_tokens", 2048) ) - input_ids = request.prompt_token_ids + request.output_token_ids + if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: + # Enable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + else: + # Disable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + + if isinstance(request.prompt_token_ids, np.ndarray): + prompt_token_ids = request.prompt_token_ids.tolist() + else: + prompt_token_ids = request.prompt_token_ids + input_ids = prompt_token_ids + request.output_token_ids logger.debug( f"Handle prefill request {request} at idx {idx}, " f"{prefill_start_index=}, {prefill_end_index=}, " @@ -285,7 +381,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = has_decode_task = True continue else: # preempted task - logger.debug(f"Handle preempted request {request} at idx {idx}") + logger.info(f"Handle preempted request {request} at idx {idx}") self.share_inputs["block_tables"][idx : idx + 1, :] = -1 self.share_inputs["stop_flags"][idx : idx + 1] = True self.seq_lens_this_time_buffer[idx : idx + 1] = 0 @@ -306,6 +402,10 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: 
int = self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = request.get("temp_scaled_logprobs", False) + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = request.get( + "top_p_normalized_logprobs", False + ) self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( @@ -344,6 +444,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + if self.speculative_method in ["mtp"]: + self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None): """ @@ -459,6 +561,15 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: ) self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: + # Enable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens") + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + else: + # Disable thinking + self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1 + self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0 + def get_attr_from_request(request, attr, default_value=None): res = request.get(attr, default_value) if res is not None: @@ -484,6 +595,12 @@ def get_attr_from_request(request, attr, default_value=None): self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request( request, "presence_penalty", 0.0 ) + self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request( + request, "temp_scaled_logprobs", False + ) + self.share_inputs["top_p_normalized_logprobs"][idx : idx + 1] = get_attr_from_request( + request, "top_p_normalized_logprobs", False + ) self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( @@ -535,27 +652,100 @@ def get_attr_from_request(request, attr, default_value=None): if self.speculative_method in ["mtp"]: self.proposer.insert_prefill_inputs(req_dicts, num_running_requests) - def _dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): - """Set dummy prefill inputs to share_inputs""" + def get_input_length_list( + self, num_tokens: int, batch_size: int, expected_decode_len: int, capture_prefill: bool = False + ): + """ + Generates some list for _dummy_prefill_inputs, when capture pure prefill or mtp, + the list should be carefully constructed. + + This function addresses a specific problem: in the pure prefill stage, variable + input lengths (e.g., `prompt[160, 0]` vs. `prompt[80, 80]`) can lead to different + CUDA Grid dimensions for kernels like `split_q_block`. This prevents CUDA Graph + reuse. + + The `split_q_block` kernel calculates the total number of blocks, which directly + determines the `griddim.x` launch parameter for the `multi_query_append_attention_kernel`. 
+ The blocks for a single sequence are determined by the formula: + `num_blocks = ceil((sequence_length * group_size) / block_shape_q)` + + Due to the `ceil` (ceiling) function, distributing a total number of tokens across + a batch of shorter sequences will result in a larger total block count. For example, + with a `group_size` of 5 and `block_shape_q` of 64: + - A single sequence of 160 tokens requires `ceil((160 * 5) / 64) = 13` blocks. + - Two sequences of 80 tokens each require `ceil((80 * 5) / 64) * 2 = 7 * 2 = 14` blocks. + + To ensure graph replayability, this function creates a "dummy" list of sequence + lengths that's designed to produce the theoretical maximum `encoder_num_blocks_x_cpu` + for the given `num_tokens` and `batch_size`. This strategy ensures the captured + CUDA Graph has the largest possible grid dimensions. At runtime, if the actual number + of blocks is less than or equal to this maximum, the kernel can safely execute by + using an early-exit mechanism. + + Args: + num_tokens (int): The total number of tokens across all sequences. + batch_size (int): The number of sequences (requests) in the batch. + + Returns: + List[int]: A list of integers representing the sequence length for each request. + This list is crafted to maximize the total number of blocks. + """ # NOTE(gongshaotian): The maximum decoding length is equal to the expected decoded tokens plus the eos token max_dec_len = expected_decode_len + 1 - full_length = min( - num_tokens // batch_size, + input_length = min( + num_tokens // (1 if capture_prefill else batch_size), self.model_config.max_model_len - max_dec_len, ) - # When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. - # Figure out the accurate buffer size of DeepEP. + # NOTE(wanglongzhi): When the full length is too large, DeepEP's buffer size will not be enough to cause the result to appear nan. + # TODO(wanglongzhi): Figure out the accurate buffer size of DeepEP. if self.fd_config.parallel_config.enable_expert_parallel: - full_length = min(full_length, 32) + input_length = min(input_length, 32) - input_length = int(full_length * self.cache_config.kv_cache_ratio) block_num = ( input_length + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + input_length_list = [input_length] * batch_size + + if capture_prefill: + if num_tokens < batch_size: + input_length_list = [1] * num_tokens + else: + input_length_list = [1] * (batch_size - 1) + input_length_list.append(num_tokens - batch_size + 1) + + len_of_input_length_list = len(input_length_list) + max_dec_len_list = [max_dec_len] * len_of_input_length_list + + return input_length_list, max_dec_len_list, block_num + + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not self.is_pooling_model: + return [] + + supported_tasks = list(model.pooler.get_supported_tasks()) + + if self.cache_config.enable_chunked_prefill and "encode" in supported_tasks: + supported_tasks.remove("encode") + + logger.warning( + "Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by export=FD_DISABLE_CHUNKED_PREFILL=1 before using it." 
+ ) + + # score not support + return supported_tasks + + def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list: List[int], block_num: int): + """Set dummy prefill inputs to share_inputs""" + batch_size = len(input_length_list) for i in range(batch_size): idx = i + input_length = input_length_list[i] + max_dec_len = max_dec_len_list[i] self.share_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["prompt_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length) self.share_inputs["eos_token_id"][:] = np.array( @@ -623,6 +813,8 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["presence_score"] = paddle.full( [max_num_seqs, 1], self.model_config.presence_score, dtype="float32" ) + self.share_inputs["temp_scaled_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool") + self.share_inputs["top_p_normalized_logprobs"] = paddle.full([max_num_seqs, 1], False, dtype="bool") self.share_inputs["min_dec_len"] = paddle.full([max_num_seqs, 1], self.model_config.min_length, dtype="int64") self.share_inputs["max_dec_len"] = paddle.full( @@ -674,18 +866,28 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["decoder_batch_ids"] = None self.share_inputs["decoder_tile_ids_per_batch"] = None self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory + self.share_inputs["decoder_num_blocks_device"] = None + self.share_inputs["decoder_chunk_size_device"] = None self.share_inputs["max_len_tensor_cpu"] = None # CPU + self.share_inputs["encoder_batch_ids"] = None + self.share_inputs["encoder_tile_ids_per_batch"] = None + self.share_inputs["encoder_num_blocks_x_cpu"] = None # CPU + self.share_inputs["kv_batch_ids"] = None + self.share_inputs["kv_tile_ids_per_batch"] = None + self.share_inputs["kv_num_blocks_x_cpu"] = None # CPU - # Initialize rotary position embedding - tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1)) + # Initialize thinking related buffers + self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32") + self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - # TODO(gongshaotian): move to models + # Initialize rotary position embedding if not self.enable_mm: self.share_inputs["rope_emb"] = get_rope( rotary_dim=self.model_config.head_dim, - position_ids=tmp_position_ids, + position_ids=paddle.arange(self.model_config.max_model_len).reshape((1, -1)), base=self.model_config.rope_theta, model_config=self.model_config, + partial_rotary_factor=self.model_config.partial_rotary_factor, ) # Set block tables @@ -749,9 +951,30 @@ def _init_share_inputs(self, max_num_seqs: int): fill_value=0, dtype="int32", ) + # For V1_KVCACHE_SCHEDULER + self.share_inputs["step_draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") + # For MTP Logprob + self.share_inputs["draft_logits"] = paddle.full( + [max_num_seqs * (self.speculative_config.num_speculative_tokens + 1), self.model_config.vocab_size], + -1, + dtype="float32", + ) + self.share_inputs["cu_batch_token_offset"] = paddle.full( + shape=[max_num_seqs + 1], fill_value=0, dtype="int32" + ) if self.enable_mm: head_dim = self.model_config.head_dim + if "qwen" in self.model_config.model_type: # neox style = True + rope_head_dim = head_dim + else: # 
neox style = False + rope_head_dim = head_dim // 2 + self.share_inputs["rope_emb"] = paddle.full( shape=[ max_num_seqs, @@ -759,7 +982,7 @@ def _init_share_inputs(self, max_num_seqs: int): 1, self.model_config.max_model_len, 1, - head_dim // 2, + rope_head_dim, ], fill_value=0, dtype="float32", @@ -777,7 +1000,11 @@ def _prepare_inputs(self) -> None: self.share_inputs["step_seq_lens_decoder"], self.share_inputs["block_tables"], self.share_inputs["is_block_step"], + self.share_inputs["draft_tokens"] if self.speculative_decoding else None, + self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None, + self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None, self.cache_config.block_size, + self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, ) # Remove padding @@ -798,7 +1025,8 @@ def _prepare_inputs(self) -> None: ) self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) - self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) + # NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions + self.share_inputs["batch_id_per_token"][:] = -1 self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -812,6 +1040,7 @@ def _prepare_inputs(self) -> None: # Initialize forward meta data self.initialize_forward_meta() + self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) # Get sampling metadata self.sampling_metadata = SamplingMetadata( @@ -835,6 +1064,9 @@ def _prepare_inputs(self) -> None: max_num_logprobs=20 if self.enable_logprob else None, enable_early_stop=self.enable_early_stop, stop_flags=self.share_inputs["stop_flags"], + temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"], + top_p_normalized_logprobs=self.share_inputs["top_p_normalized_logprobs"], + share_inputs=self.share_inputs, ) def load_model(self) -> None: @@ -873,6 +1105,10 @@ def initialize_forward_meta(self): decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"], + # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, + # adapted to cudagraph. 
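+ # decoder_num_blocks_device and decoder_chunk_size_device are the GPU-resident counterparts of the CPU-side metadata, + # so the attention path can read them on-device when this step is captured and replayed as a CUDA graph.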
+ decoder_num_blocks_device=self.share_inputs["decoder_num_blocks_device"], + decoder_chunk_size_device=self.share_inputs["decoder_chunk_size_device"], max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"], @@ -882,23 +1118,29 @@ def initialize_forward_meta(self): cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"], caches=self.share_inputs["caches"], + encoder_batch_ids=self.share_inputs["encoder_batch_ids"], + encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], + encoder_num_blocks_x_cpu=self.share_inputs["encoder_num_blocks_x_cpu"], + kv_batch_ids=self.share_inputs["kv_batch_ids"], + kv_tile_ids_per_batch=self.share_inputs["kv_tile_ids_per_batch"], + kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], ) - # Update Batch type for cuda graph - only_decode_batch = True - prefill_exists = None - # mix ep in single node + # Update Batch type for cuda graph for only_decode_batch + if_only_decode = self.only_decode() + only_decode_use_cudagraph = self.use_cudagraph and if_only_decode + + # Update config about moe for better performance + # TODO(wanglongzhi):Modifying the config at runtime is not appropriate; it needs to be moved to forward_meta. It will be used in MoEMethodBase.apply() if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": - only_decode_batch_list = [] - prefill_exists = self.exist_prefill() - paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) - only_decode_batch = all(only_decode_batch_list) - self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + self.fd_config.model_config.moe_phase.phase = "decode" if if_only_decode else "prefill" + + # Update Batch type for cuda graph for only_prefill_batch + only_prefill_use_cudagraph = self.use_cudagraph and self.cudagraph_only_prefill and self.only_prefill() + # When support capture both prefill-only and decode-only, this will use [only_prefill_use_cudagraph or only_decode_use_cudagraph] self.forward_meta.step_use_cudagraph = ( - self.use_cudagraph - and only_decode_batch - and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + only_prefill_use_cudagraph if self.cudagraph_only_prefill else only_decode_use_cudagraph ) # Initialzie attention meta data @@ -909,7 +1151,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: """ Initialize kv cache """ - cache_kvs = {} + # cache_kvs = {} max_block_num = self.num_gpu_blocks # Get kv cache dtype @@ -928,42 +1170,83 @@ def initialize_kv_cache(self, profile: bool = False) -> None: kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type ) + if kv_cache_quant_type == "block_wise_fp8": + kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not profile and ( - self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed" - ): - cache_kvs_list = [] - for i in range(self.model_config.num_hidden_layers): - key_cache = paddle.empty(shape=[], dtype=cache_type) - key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" + cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) + 
cache_ready_signal = IPCSignal( + name="cache_ready_signal", + array=cache_ready_signal_data, + dtype=np.int32, + suffix=self.parallel_config.engine_worker_queue_port, + create=False, + ) + + # Check if gpu runner needs to create kv cache + # 1. During profiling, it creates its own kv cache. + # 2. GPU runner creates kv cache tensor unless p/d disaggregation is enabled. + create_cache_tensor = profile or self.scheduler_config.splitwise_role == "mixed" + + if not create_cache_tensor: + logger.info(f"Waiting for cache managers to create kv cache.. {cache_ready_signal.value}") + while cache_ready_signal.value[local_rank] != 1: + time.sleep(1) + logger.info(f"OK! Stop waiting. {cache_ready_signal.value}") + + logger.info(f"Initializing kv cache for all layers. {cache_ready_signal.value}") + cache_kvs_list = [] + + # NOTE:(changwenbin) Determine whether it is Multi-Head Latent Attention, + # To rationalize the allocation of kvcache. + from fastdeploy import envs + + self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" + for i in range(self.model_config.num_hidden_layers): + key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}" + if not self.mla_cache: val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}" + if create_cache_tensor: + logger.info(f"..creating kv cache for layer {i}: {kv_cache_shape}") + key_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + set_data_ipc(key_cache, key_cache_name) + if not self.mla_cache: + val_cache = paddle.full(shape=kv_cache_shape, fill_value=0, dtype=cache_type) + set_data_ipc(val_cache, val_cache_name) + cache_kvs_list.extend([key_cache, val_cache]) + else: + cache_kvs_list.extend([key_cache]) + if kv_cache_quant_type == "block_wise_fp8": + key_cache_scales = paddle.full( + shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() + ) + if not self.mla_cache: + val_cache_scales = paddle.full( + shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() + ) + cache_kvs_list.extend([key_cache_scales, val_cache_scales]) + else: + cache_kvs_list.extend([key_cache_scales]) + else: + logger.info(f"..attaching kv cache for layer {i}: {kv_cache_shape}") + key_cache = paddle.empty(shape=[], dtype=cache_type) key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape) - cache_kvs_list.append(key_cache) - value_cache = paddle.empty(shape=[], dtype=cache_type) - value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape) - cache_kvs_list.append(value_cache) + if not self.mla_cache: + val_cache = paddle.empty(shape=[], dtype=cache_type) + val_cache = share_external_data(val_cache, val_cache_name, kv_cache_shape) + cache_kvs_list.extend([key_cache, val_cache]) + else: + cache_kvs_list.extend([key_cache]) - self.share_inputs["caches"] = cache_kvs_list + self.share_inputs["caches"] = cache_kvs_list - else: - for i in range(self.model_config.num_hidden_layers): - cache_kvs[f"key_caches_{i}"] = paddle.full( - shape=kv_cache_shape, - fill_value=0, - dtype=cache_type, - ) - cache_kvs[f"value_caches_{i}"] = paddle.full( - shape=kv_cache_shape, - fill_value=0, - dtype=cache_type, - ) - self.share_inputs["caches"] = list(cache_kvs.values()) - for value in cache_kvs.values(): - del value - # paddle.device.empty_cache() + if not profile and create_cache_tensor: + cache_ready_signal.value[local_rank] = 1 + logger.info(f"✅ kv cache is ready! 
{cache_ready_signal.value}") - def initialize_attn_backend(self) -> None: + paddle.device.empty_cache() + + def _initialize_attn_backend(self) -> None: """ Initialize attention backends """ @@ -980,13 +1263,37 @@ def initialize_attn_backend(self) -> None: encoder_block_shape_q = 64 decoder_block_shape_q = 16 decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1 - decode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( - (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q + group_size = np.ceil(num_heads / self.model_config.kv_num_heads) + + # NOTE: (changwenbin) When using auto_chunk, + # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K. + decode_max_tile_size = ( + 1024 + * self.scheduler_config.max_num_seqs + * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q) + ) + encode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( + (self.model_config.max_model_len * group_size) / encoder_block_shape_q + ) + kv_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( + self.model_config.max_model_len / self.fd_config.cache_config.block_size ) self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") - # self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() - # self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu() + self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, + # adapted to cudagraph. 
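+ # Unlike decoder_num_blocks_cpu above, the two buffers below are created without .cpu()/.pin_memory() and stay on the GPU; + # the chunk size starts at 64 and the block count is filled in at run time.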
+ self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") + self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") + self.share_inputs["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() + + self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") + self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") + self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() # Get the attention backend attn_cls = get_attention_backend() @@ -1001,12 +1308,178 @@ def initialize_attn_backend(self) -> None: self.attn_backends.append(attn_backend) + def _dummy_pooler_run_task( + self, + hidden_states: paddle.Tensor, + task: PoolingTask, + ) -> PoolerOutput: + num_tokens = hidden_states.shape[0] + max_num_seqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_seqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + req_num_tokens = num_tokens // num_reqs + + dummy_prompt_lens = paddle.to_tensor(num_scheduled_tokens_list, dtype="int64") + dummy_token_ids = paddle.zeros( + [num_reqs, req_num_tokens], + dtype="int64", + ) + model = cast(FdModelForPooling, self.get_model()) + dummy_pooling_params = PoolingParams(task=task) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(dummy_pooling_params) + + dummy_metadata = PoolingMetadata( + prompt_lens=dummy_prompt_lens, + prompt_token_ids=dummy_token_ids, + pooling_params=[dummy_pooling_params] * num_reqs, + ) + dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, device=hidden_states.place) + + try: + return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata) + except RuntimeError as e: + if "out of memory" in str(e): + raise RuntimeError( + "CUDA out of memory occurred when warming up pooler " + f"({task=}) with {num_reqs} dummy requests. Please try " + "lowering `max_num_seqs` or `gpu_memory_utilization` when " + "initializing the engine." 
+ ) from e + else: + raise e + + def _dummy_pooler_run( + self, + hidden_states: paddle.Tensor, + ) -> PoolerOutput: + output_size = dict[PoolingTask, float]() + for task in self.get_supported_pooling_tasks(): + output = self._dummy_pooler_run_task(hidden_states, task) + output_size[task] = output.get_data_nbytes() + del output + + max_task = max(output_size.items(), key=lambda x: x[1])[0] + final_output = self._dummy_pooler_run_task(hidden_states, max_task) + + return final_output + + def _dummy_sampler_run( + self, + hidden_states: paddle.Tensor, + model_output: paddle.Tensor, + ) -> paddle.Tensor: + logits = self.model.compute_logits(hidden_states) + + if not self.speculative_decoding: + set_value_by_flags_and_idx( + self.share_inputs["pre_ids"], + self.share_inputs["input_ids"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_idx"], + self.share_inputs["stop_flags"], + ) + sampler_output = self.sampler(logits, self.sampling_metadata) + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + sampler_output.sampled_token_ids, + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + else: + self.sampler( + logits, + self.sampling_metadata, + self.model_config.max_model_len, + self.share_inputs, + ) + sampler_output = None + if self.parallel_config.tensor_parallel_size > 1: + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["accept_num"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["step_idx"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["stop_flags"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + # 5. 
post process + model_output_data = ModelOutputData( + next_tokens=self.share_inputs["next_tokens"], + stop_flags=self.share_inputs["stop_flags"], + step_idx=self.share_inputs["step_idx"], + max_dec_len=self.share_inputs["max_dec_len"], + pre_ids=self.share_inputs["pre_ids"], + seq_lens_this_time=self.share_inputs["seq_lens_this_time"], + eos_token_id=self.share_inputs["eos_token_id"], + not_need_stop=self.share_inputs["not_need_stop"], + input_ids=self.share_inputs["input_ids"], + stop_nums=self.share_inputs["stop_nums"], + seq_lens_encoder=self.share_inputs["seq_lens_encoder"], + seq_lens_decoder=self.share_inputs["seq_lens_decoder"], + is_block_step=self.share_inputs["is_block_step"], + full_hidden_states=model_output, + msg_queue_id=self.parallel_config.msg_queue_id, + mp_rank=self.parallel_config.tensor_parallel_rank, + use_ep=self.parallel_config.use_ep, + draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), + actual_draft_token_num=( + self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None + ), + accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), + accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), + stop_token_ids=self.share_inputs["stop_seqs"], + stop_seqs_len=self.share_inputs["stop_seqs_len"], + ) + + post_process( + sampler_output=sampler_output, + model_output=model_output_data, + share_inputs=self.share_inputs, + block_size=self.cache_config.block_size, + speculative_decoding=self.speculative_decoding, + skip_save_output=True, + async_output_queue=self.async_output_queue, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, + ) + if self.speculative_decoding: + if self.speculative_method == "mtp": + self.proposer.run( + full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph + ) + else: + self.proposer.run(share_inputs=self.share_inputs) + + return sampler_output + def _dummy_run( self, num_tokens: paddle.Tensor, batch_size: paddle.Tensor, expected_decode_len: int = 1, in_capturing: bool = False, + capture_prefill: bool = False, + accept_all_drafts: bool = False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. @@ -1014,11 +1487,20 @@ def _dummy_run( num_tokens: expected_decode_len: Expected number of tokens generated in_capturing: Is cuda graph in capturing state + capture_prefill: Capture pure prefill for cuda graph + accept_all_drafts: Target model will accept all draft tokens """ - self._dummy_prefill_inputs( + + input_length_list, max_dec_len_list, block_num = self.get_input_length_list( num_tokens=num_tokens, batch_size=batch_size, expected_decode_len=expected_decode_len, + capture_prefill=capture_prefill, + ) + self._dummy_prefill_inputs( + input_length_list=input_length_list, + max_dec_len_list=max_dec_len_list, + block_num=block_num, ) if self.speculative_method in ["mtp"]: self.proposer.dummy_prefill_inputs( @@ -1026,8 +1508,8 @@ def _dummy_run( batch_size=batch_size, expected_decode_len=expected_decode_len, ) - while True: + while True: # 1. 
Initialize forward meta and attention meta data self._prepare_inputs() @@ -1042,98 +1524,31 @@ def _dummy_run( self.share_inputs["image_features"], self.forward_meta, ) - hidden_states = model_output else: model_output = self.model( ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] - hidden_states = rebuild_padding( - model_output, - self.share_inputs["cu_seqlens_q"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["seq_lens_encoder"], - ( - self.share_inputs["output_padding_offset"] if self.speculative_decoding else None - ), # speculative decoding requires - self.model_config.max_model_len, - ) - - # 4. Execute spec decode - logits = self.model.compute_logits(hidden_states) - - if not self.speculative_decoding: - set_value_by_flags_and_idx( - self.share_inputs["pre_ids"], - self.share_inputs["input_ids"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_encoder"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["step_idx"], - self.share_inputs["stop_flags"], - ) - sampler_output = self.sampler(logits, self.sampling_metadata) - if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) - else: - self.sampler( - logits, - self.sampling_metadata, - self.model_config.max_model_len, - self.share_inputs, - ) - sampler_output = None - if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0) - paddle.distributed.broadcast(self.share_inputs["accept_num"], 0) - paddle.distributed.broadcast(self.share_inputs["step_idx"], 0) - paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0) - - # 5. 
post process - model_output_data = ModelOutputData( - next_tokens=self.share_inputs["next_tokens"], - stop_flags=self.share_inputs["stop_flags"], - step_idx=self.share_inputs["step_idx"], - max_dec_len=self.share_inputs["max_dec_len"], - pre_ids=self.share_inputs["pre_ids"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time"], - eos_token_id=self.share_inputs["eos_token_id"], - not_need_stop=self.share_inputs["not_need_stop"], - input_ids=self.share_inputs["input_ids"], - stop_nums=self.share_inputs["stop_nums"], - seq_lens_encoder=self.share_inputs["seq_lens_encoder"], - seq_lens_decoder=self.share_inputs["seq_lens_decoder"], - is_block_step=self.share_inputs["is_block_step"], - full_hidden_states=model_output, - msg_queue_id=self.parallel_config.msg_queue_id, - mp_rank=self.local_rank, - use_ep=self.parallel_config.use_ep, - draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), - actual_draft_token_num=( - self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None - ), - accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None), - accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), - stop_token_ids=self.share_inputs["stop_seqs"], - stop_seqs_len=self.share_inputs["stop_seqs_len"], - ) - - post_process( - sampler_output=sampler_output, - model_output=model_output_data, - share_inputs=self.share_inputs, - block_size=self.cache_config.block_size, - speculative_decoding=self.speculative_decoding, - skip_save_output=True, + hidden_states = rebuild_padding( + model_output, + self.share_inputs["cu_seqlens_q"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + ( + self.share_inputs["output_padding_offset"] if self.speculative_decoding else None + ), # speculative decoding requires + self.model_config.max_model_len, ) - if self.speculative_decoding: - if self.speculative_method == "mtp": - self.proposer.run(full_hidden_states=model_output) - else: - self.proposer.run(share_inputs=self.share_inputs) + if self.is_pooling_model: + self._dummy_pooler_run(hidden_states) + break + else: + self._dummy_sampler_run(hidden_states, model_output) # 7. 
Updata 'infer_seed' and step_cuda() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) @@ -1145,7 +1560,6 @@ def _dummy_run( self.speculative_config, self.cache_config.enable_prefix_caching, ) - if int((self.share_inputs["seq_lens_this_time"] > 0).sum()) == 0: break @@ -1155,13 +1569,15 @@ def _update_chunked_prefill(self, tasks): """ if not self.cache_config.enable_chunked_prefill: return - for task in tasks: - if task.get("prefill_chunk_info", None) is None: - continue - if task.chunk_idx > len(task.prefill_chunk_info): - continue - self.restore_chunked_prefill_request[task.request_id] = task + if tasks is not None: + for task in tasks: + if task.get("prefill_chunk_info", None) is None: + continue + + if task.chunk_idx > len(task.prefill_chunk_info): + continue + self.restore_chunked_prefill_request[task.request_id] = task for id, task in list(self.restore_chunked_prefill_request.items()): idx = task.idx @@ -1211,6 +1627,7 @@ def _update_chunked_prefill(self, tasks): self.proposer.update_task_chunk_prefill(task) task.chunk_idx += 1 + @sot_warmup_guard(True) def capture_model(self) -> None: """ Trigger CUDA Graph capture for all shapes in cuda graph capture list @@ -1221,14 +1638,67 @@ def capture_model(self) -> None: time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - for batch_size in sorted(capture_sizes, reverse=True): - self._dummy_run( - num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=batch_size, - in_capturing=True, - expected_decode_len=expected_decode_len, - ) - logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + if self.fd_config.graph_opt_config.cudagraph_only_prefill: + for num_tokens in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=num_tokens, + batch_size=self.scheduler_config.max_num_seqs, + in_capturing=True, + expected_decode_len=expected_decode_len, + capture_prefill=True, + ) + logger.info( + f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" + ) + elif self.speculative_decoding and self.speculative_method == "mtp": + # Capture Target Model without bsz 1 + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture target model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=1, + ) + logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}") + # Capture Draft Model without bsz 1 + # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture Draft model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=True, + ) + logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}") + # Capture Draft Model with bsz 1 + if 1 in capture_sizes: + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=int(1), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=False, + ) + logger.info(f"Warm up the Draft model 
with the num_tokens:{batch_size}, expected_decode_len:{3}") + + else: + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") @@ -1252,10 +1722,15 @@ def _get_skip_idx(self, model_forward_batch: Optional[List[Request]] = None): Returns: A list of indices corresponding to the requests that need to be skipped. """ - skip_idx_list = [] - if not self.cache_config.enable_chunked_prefill or self.guided_backend is None: - return skip_idx_list + if ( + not self.cache_config.enable_chunked_prefill + or self.guided_backend is None + or model_forward_batch is None + or envs.ENABLE_V1_KVCACHE_SCHEDULER + ): + return [] + skip_idx_list = [] for task in model_forward_batch: if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): continue @@ -1304,24 +1779,30 @@ class at the server level, which is too granular for ModelRunner. self.share_inputs["image_features"], self.forward_meta, ) - hidden_states = model_output else: model_output = self.model( ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta, ) - hidden_states = rebuild_padding( - model_output, - self.share_inputs["cu_seqlens_q"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["seq_lens_encoder"], - (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), - self.model_config.max_model_len, - ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] + hidden_states = rebuild_padding( + model_output, + self.share_inputs["cu_seqlens_q"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["seq_lens_encoder"], + (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), + self.model_config.max_model_len, + ) + logits = None # 4. Compute logits, Sample - logits = self.model.compute_logits(hidden_states) + if hasattr(self.model, "is_pooling_model") and self.model.is_pooling_model: + # TODO(lizexu123) The execution of the pooling function have not been implemented yet. + pass + else: + logits = self.model.compute_logits(hidden_states) if not self.speculative_decoding: set_value_by_flags_and_idx( @@ -1339,21 +1820,40 @@ class at the server level, which is too granular for ModelRunner. 
skip_idx_list, ) if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) + paddle.distributed.broadcast( + sampler_output.sampled_token_ids, + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) else: - self.sampler( + sampler_output = self.sampler( logits, self.sampling_metadata, self.model_config.max_model_len, self.share_inputs, ) - sampler_output = None if self.parallel_config.tensor_parallel_size > 1: - paddle.distributed.broadcast(self.share_inputs["accept_tokens"], 0) - paddle.distributed.broadcast(self.share_inputs["accept_num"], 0) - paddle.distributed.broadcast(self.share_inputs["step_idx"], 0) - paddle.distributed.broadcast(self.share_inputs["stop_flags"], 0) + paddle.distributed.broadcast( + self.share_inputs["accept_tokens"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["accept_num"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["step_idx"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) + paddle.distributed.broadcast( + self.share_inputs["stop_flags"], + self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size, + group=self.parallel_config.tp_group, + ) # 5. Post Process model_output_data = ModelOutputData( @@ -1372,7 +1872,7 @@ class at the server level, which is too granular for ModelRunner. is_block_step=self.share_inputs["is_block_step"], full_hidden_states=model_output, msg_queue_id=self.parallel_config.msg_queue_id, - mp_rank=self.local_rank, + mp_rank=self.parallel_config.tensor_parallel_rank, use_ep=self.parallel_config.use_ep, draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None), actual_draft_token_num=( @@ -1382,6 +1882,7 @@ class at the server level, which is too granular for ModelRunner. accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None), stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], + prompt_lens=self.share_inputs["prompt_lens"], ) if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": @@ -1396,16 +1897,23 @@ class at the server level, which is too granular for ModelRunner. save_each_rank=self.parallel_config.use_ep, speculative_decoding=self.speculative_decoding, skip_save_output=skip_save_output, + async_output_queue=self.async_output_queue, + think_end_id=self.model_config.think_end_id, + line_break_id=self.model_config.line_break_id, ) + if self.guided_backend is not None and sampler_output is not None: + self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) # 6. Speculative decode if self.speculative_decoding: if self.speculative_method == "mtp": - self.proposer.run(full_hidden_states=model_output) + self.proposer.run( + full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph + ) else: self.proposer.run(share_inputs=self.share_inputs) - # 7. Updata 'infer_seed' and step_cuda() + # 7. 
Update 'infer_seed' and step_cuda() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED if not envs.ENABLE_V1_KVCACHE_SCHEDULER: @@ -1419,6 +1927,27 @@ class at the server level, which is too granular for ModelRunner. self._update_chunked_prefill(model_forward_batch) self._add_cache(model_forward_batch) + elif self.speculative_decoding: + speculate_schedule_cache( + self.share_inputs["draft_tokens"], + self.share_inputs["block_tables"], + self.share_inputs["stop_flags"], + self.share_inputs["prompt_lens"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_encoder"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["step_draft_tokens"], + self.share_inputs["step_seq_lens_this_time"], + self.share_inputs["accept_num"], + self.share_inputs["accept_tokens"], + self.share_inputs["is_block_step"], + self.share_inputs["not_need_stop"], + self.share_inputs["stop_nums"], + self.cache_config.block_size, + self.speculative_config.num_speculative_tokens, + ) + self.seq_lens_this_time_buffer[:num_running_requests].copy_( self.share_inputs["seq_lens_this_time"][:num_running_requests], False ) @@ -1428,7 +1957,7 @@ def _add_cache(self, model_forward_batch) -> None: """ Add cache for guided decoding. """ - if self.guided_backend is None: + if self.guided_backend is None or model_forward_batch is None: return for request in model_forward_batch: @@ -1461,20 +1990,21 @@ def profile_run(self) -> None: # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.cache_config.total_block_num self.initialize_kv_cache(profile=True) + if self.speculative_method in ["mtp"]: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) # 1. Profile with multimodal encoder & encoder cache # 2. Dummy run self._dummy_run( num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=min(self.scheduler_config.max_num_seqs, 3), + batch_size=self.scheduler_config.max_num_seqs, ) # 3. gc self.clear_cache() - if self.speculative_method in ["mtp"]: - self.proposer.clear_dummy_input() + self.proposer.clear_mtp_cache() def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ @@ -1504,7 +2034,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: ) if self.speculative_method in ["mtp"]: - self.proposer.update_block_num(num_gpu_blocks) + self.proposer.update_mtp_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): """ @@ -1536,7 +2066,19 @@ def cal_theortical_kvcache(self): if self.speculative_method in ["mtp"] else self.model_config.num_hidden_layers ) - required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v + + # NOTE:(changwenbin) Determie whether it is Multi-Head Latent Attention, + # To rationalize the allocation of kvcache. 
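+ # With MLA the cache holds one compressed latent per token (kv_lora_rank + qk_rope_head_dim entries, i.e. compressed KV plus rotary key) + # instead of separate K and V heads, hence the single-factor formula below.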
+ self.mla_cache = envs.FD_ATTENTION_BACKEND == "MLA_ATTN" + if self.mla_cache: + required_memory = ( + byte_of_dtype + * (self.fd_config.model_config.kv_lora_rank + self.fd_config.model_config.qk_rope_head_dim) + * (self.cache_config.block_size) + * num_layers + ) # compress_kv + k_pe + else: + required_memory = byte_of_dtype * 2 * (self.cache_config.block_size * hidden_dim) * num_layers # k + v return required_memory def not_need_stop(self) -> bool: @@ -1548,18 +2090,35 @@ def clear_cache(self): self.share_inputs.pop("caches", None) if self.forward_meta is not None: self.forward_meta.clear_caches() + paddle.device.empty_cache() def clear_parameters(self, pid): - """ " Dynamic model loader use to clear parameters use for RL""" + """Used by the dynamic model loader to clear parameters for RL""" + # Clear CUDAGraph + if self.use_cudagraph: + self.model.clear_grpah_opt_backend() + # Clear parameters and send signal self.dynamic_weight_manager.clear_parameters(pid) self.clear_cache() - # paddle.device.empty_cache() + paddle.device.empty_cache() + self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") + def clear_requests(self): + """Used by the dynamic model loader to clear requests for RL""" + self.share_inputs["stop_flags"][:] = True + def update_parameters(self, pid): - """ " Dynamic model loader use to update parameters use for RL""" + """Used by the dynamic model loader to update parameters for RL""" + # Update parameters self.dynamic_weight_manager.update_parameters(pid) self.initialize_kv_cache() + # Recapture CUDAGraph + if self.use_cudagraph: + self.capture_model() + # Send signal + self.dynamic_weight_manager.finalize_update(pid) + self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory") def padding_cudagraph_inputs(self) -> None: @@ -1568,6 +2127,11 @@ def padding_cudagraph_inputs(self) -> None: In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. """ # In init_attention_metadata, the decode buffer has already been cleared + + # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
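+ # A replayed graph must see the shapes it was captured with, so seq_lens_this_time is pointed back at the + # full-size buffer and real_token_num records how many tokens of the padded batch are actually valid.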
+ if self.use_cudagraph: + self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] return def _init_image_preprocess(self) -> None: @@ -1604,7 +2168,7 @@ def _preprocess_mm_task(self, one: dict) -> None: image_type_ids = one["image_type_ids"][np.newaxis, :] images = one["images"] image_type_ids = paddle.to_tensor(image_type_ids, dtype=paddle.int64) - images = paddle.to_tensor(images, dtype="uint8") + images = paddle.to_tensor(images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16") grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64") else: image_type_ids = None @@ -1626,12 +2190,9 @@ def _preprocess_mm_task(self, one: dict) -> None: ) return result - @paddle.no_grad() - def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: - """extract_vision_features""" + def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: assert inputs["images"] is not None grid_thw = inputs["grid_thw"] - images = inputs["images"].cast("float32") images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor images = images / self.image_preprocess.image_std_tensor @@ -1665,6 +2226,14 @@ def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: ) return image_features + @paddle.no_grad() + def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: + """extract_vision_features""" + if "ernie" in self.model_config.model_type: + return self.extract_vision_features_ernie(inputs) + else: + raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") + @paddle.no_grad() def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Tensor: """prepare_rope3d""" @@ -1684,5 +2253,6 @@ def prepare_rope3d(self, position_ids: paddle.Tensor, max_len: int) -> paddle.Te base=self.model_config.rope_theta, max_position=self.model_config.max_model_len, freq_allocation=getattr(self.model_config, "freq_allocation", 20), + model_type=self.model_config.model_type, ) return rope_emb diff --git a/fastdeploy/worker/metax_worker.py b/fastdeploy/worker/metax_worker.py index c30c067620..675c2a9e0e 100644 --- a/fastdeploy/worker/metax_worker.py +++ b/fastdeploy/worker/metax_worker.py @@ -20,13 +20,12 @@ from typing import List, Optional import paddle -import pymxsml from paddle import nn from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.utils import get_logger +from fastdeploy.utils import get_logger, set_random_seed from fastdeploy.worker.metax_model_runner import MetaxModelRunner from fastdeploy.worker.output import ModelRunnerOutput from fastdeploy.worker.worker_base import WorkerBase @@ -53,23 +52,21 @@ def init_device(self): Initialize device and construct model runner """ self.max_chips_per_node = 8 - if paddle.is_compiled_with_custom_device("metax_gpu"): - # Set environment variable - self.device_ids = self.parallel_config.device_ids.split(",") - self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}" - paddle.device.set_device(self.device) - paddle.set_default_dtype(self.model_config.dtype) + # Set environment variable + self.device_ids = self.parallel_config.device_ids.split(",") + self.device = f"metax_gpu:{self.local_rank % self.max_chips_per_node}" + paddle.device.set_device(self.device) + paddle.set_default_dtype(self.model_config.dtype) - 
gc.collect() - - else: - raise RuntimeError(f"Not support device type: {self.device_config.device}") + gc.collect() + paddle.device.empty_cache() + set_random_seed(self.fd_config.model_config.seed) # Construct model runner self.model_runner: MetaxModelRunner = MetaxModelRunner( fd_config=self.fd_config, device=self.device, - device_id=self.device_ids[self.local_rank % self.max_chips_per_node], + device_id=int(self.device_ids[self.local_rank % self.max_chips_per_node]), rank=self.rank, local_rank=self.local_rank, ) @@ -99,6 +96,8 @@ def determine_available_memory(self) -> int: if fd_kvache_mem is not None: return int(float(fd_kvache_mem) * 1024**3) else: + import pymxsml + # 1. Record memory state before profile run start_time = time.perf_counter() Gb = 1024**3 @@ -200,9 +199,10 @@ def graph_optimize_and_warm_up_model(self) -> None: """ Perform the warm-up and the graph optimization """ - if self.model_runner.graph_opt_level >= 1: + if self.fd_config.graph_opt_config.graph_opt_level >= 1 and not self.model_runner.use_cudagraph: self.model_runner.sot_warmup() - # Todo Trigger cuda graph capture. + # Trigger cuda graph capture + self.model_runner.capture_model() def check_health(self) -> bool: """ """
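Reviewer note: the metadata buffer sizing added in `_initialize_attn_backend` is easiest to sanity-check in isolation. Below is a minimal standalone sketch of the same arithmetic; the function name `attn_metadata_buffer_sizes` and the example parameter values are invented for illustration, and only the formulas mirror the patch.

import numpy as np

def attn_metadata_buffer_sizes(
    max_num_seqs: int,
    max_model_len: int,
    num_heads: int,
    kv_num_heads: int,
    block_size: int,
    num_speculative_tokens: int = 0,
    encoder_block_shape_q: int = 64,
    decoder_block_shape_q: int = 16,
):
    """Upper bounds used to size the decoder/encoder/kv tile buffers (all int32 tensors)."""
    group_size = np.ceil(num_heads / kv_num_heads)
    decoder_step_token_num = num_speculative_tokens + 1
    # The x1024 factor covers the auto-chunked decode worst case mentioned in the patch
    # (KV chunking for long contexts, around 128K).
    decode_max_tile_size = 1024 * max_num_seqs * np.ceil(
        (decoder_step_token_num * group_size) / decoder_block_shape_q
    )
    encode_max_tile_size = max_num_seqs * np.ceil(
        (max_model_len * group_size) / encoder_block_shape_q
    )
    kv_max_tile_size = max_num_seqs * np.ceil(max_model_len / block_size)
    return int(decode_max_tile_size), int(encode_max_tile_size), int(kv_max_tile_size)

# Illustrative values only, not taken from any FastDeploy config:
print(attn_metadata_buffer_sizes(max_num_seqs=128, max_model_len=131072,
                                 num_heads=64, kv_num_heads=8, block_size=64))

The encoder and kv bounds are simply per-request tile counts at max_model_len; only the decode bound carries the extra 1024 factor, trading a larger (but still cheap, int32) allocation for covering every chunk size the decode path may pick.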