@@ -14,7 +14,9 @@
#include "cute/tensor.hpp"
#include "helper.h"
#include "paddle/extension.h"
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
66 changes: 66 additions & 0 deletions custom_ops/gpu_ops/append_attn/mem_util.cuh
@@ -15,6 +15,7 @@

#include <cuda_runtime.h>
#include <stdint.h>
#include <cooperative_groups/memcpy_async.h>

enum class SharedMemFillMode { kFillZero, kNoFill };

@@ -42,18 +43,35 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
}

__device__ __forceinline__ void commit_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
{}
#else
asm volatile("cp.async.commit_group;\n" ::);
#endif
}

template <size_t n>
__device__ __forceinline__ void wait_group() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
cooperative_groups::wait(cooperative_groups::this_thread_block());
#else
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}

template <PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
#else
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
asm volatile(
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
@@ -68,6 +86,7 @@ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
"n"(16),
"r"(16));
}
#endif
}
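For context on how these helpers compose on the CUDA path: load_128b issues an asynchronous 16-byte cp.async copy into shared memory, commit_group closes the batch of outstanding copies, and wait_group<n> blocks until at most n committed groups are still in flight (the Metax fallbacks above emulate this with plain memset/memcpy and a block-wide wait). A minimal, hedged sketch of the usual staging pattern; the kernel, tile size, and launch shape are illustrative, not code from this PR:

__global__ void stage_and_copy_example(const uint4* __restrict__ src,
                                       uint4* __restrict__ dst,
                                       int n) {
  __shared__ uint4 staging[256];       // one 16-byte fragment per thread
  int tid = threadIdx.x;
  if (tid < n) {
    // Issue the asynchronous 128-bit copy into shared memory.
    load_128b<PrefetchMode::kPrefetch>(&staging[tid], &src[tid]);
  }
  commit_group();                      // close the current cp.async group
  wait_group<0>();                     // wait until every committed group has landed
  __syncthreads();                     // publish the staged data block-wide
  if (tid < n) dst[tid] = staging[tid];
}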

template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -76,6 +95,28 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
}
} else {
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 16);
}
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 16 : 0;
if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
@@ -115,6 +156,7 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr,
"n"(16));
}
}
#endif
}
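The pred_* variants handle ragged tile edges: the predicate says whether the source element is in range, and kFillZero zeroes the shared-memory slot instead of reading out of bounds. A hedged usage sketch, assuming the elided middle parameter of pred_load_128b is the global-memory source pointer (which is what the Metax fallback bodies above use):

// Hedged sketch: guard the tail of a partially filled tile.
__device__ __forceinline__ void load_row_guarded_example(uint4* smem_slot,
                                                          const uint4* gmem_slot,
                                                          int row,
                                                          int valid_rows) {
  const bool in_bounds = row < valid_rows;
  // Out-of-range rows are zero-filled in shared memory rather than read.
  pred_load_128b<PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
      smem_slot, gmem_slot, in_bounds);
}

pred_load_64b and pred_load_32b below follow the same pattern at 8-byte and 4-byte granularity, which is why their Metax fallbacks differ only in the memset/memcpy width.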

template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -123,6 +165,17 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 8);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 8 : 0;
asm volatile(
@@ -141,6 +194,7 @@ __device__ __forceinline__ void pred_load_64b(T* smem_ptr,
"l"(gmem_ptr),
"n"(8));
}
#endif
}

template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
@@ -149,6 +203,17 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
bool predicate) {
uint32_t smem_int_ptr =
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, src_in_bytes);
} else {
if (predicate) {
memcpy(__cvta_shared_to_generic(smem_int_ptr), (void *)gmem_ptr, 4);
}
}
#else
if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
int src_in_bytes = predicate ? 4 : 0;
asm volatile(
@@ -167,6 +232,7 @@ __device__ __forceinline__ void pred_load_32b(T* smem_ptr,
"l"(gmem_ptr),
"n"(4));
}
#endif
}

template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
5 changes: 4 additions & 1 deletion custom_ops/gpu_ops/helper.h
@@ -592,10 +592,13 @@ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
#endif

inline int GetSMVersion() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
return 80;
#else
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;

#endif
}
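Hard-coding the return value to 80 steers the Metax build onto the SM80 (Ampere-style) dispatch paths wherever this helper is consulted. A hedged sketch of a typical call site; the helper name is hypothetical and not part of this PR:

// Hypothetical dispatch helper, shown only to illustrate how GetSMVersion() is consumed.
inline bool PreferSm80Kernels() {
  static const bool use_sm80 = GetSMVersion() >= 80;  // always true on the Metax build
  return use_sm80;
}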

inline bool GetMlaUseTensorcore() {
16 changes: 15 additions & 1 deletion custom_ops/gpu_ops/noauxtc_kernel.h
@@ -18,6 +18,7 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include "helper.h"
#include <cuda/std/limits>

namespace cg = cooperative_groups;

@@ -601,7 +602,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
topk_sum += cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
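The change above only qualifies the call with the cg:: namespace alias, so it resolves to the cooperative-groups reduction from <cooperative_groups/reduce.h> rather than any other reduce in scope. For reference, a hedged standalone sketch of that primitive (the tile width is illustrative):

// Every lane contributes one value; all lanes receive the tile-wide sum.
__device__ float tile_sum_example(cg::thread_block_tile<32> tile, float v) {
  return cg::reduce(tile, v, cg::plus<float>());
}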

@@ -658,6 +659,11 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores, scores_with_bias, num_tokens, num_cases, n_group, num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
@@ -671,13 +677,20 @@ void invokeNoAuxTc(T* scores,
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
num_tokens, num_cases, n_group, num_experts / n_group);
#endif

int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);

#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_kernel<T, IdxT><<<topk_with_k_group_num_blocks, BLOCK_SIZE, dynamic_smem_in_bytes, stream>>>(
scores, group_scores, topk_values, topk_indices, scores_with_bias,
num_tokens, n_group, topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
Expand All @@ -691,6 +704,7 @@ void invokeNoAuxTc(T* scores,
topk_values, topk_indices, scores_with_bias, num_tokens,
n_group, topk_group, topk, num_experts,
num_experts / n_group, renormalize, routed_scaling_factor);
#endif
}
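On the CUDA path invokeNoAuxTc goes through cudaLaunchKernelEx with a cudaLaunchConfig_t (the attribute setup lives in the elided lines), while the Metax build falls back to the classic triple-chevron launch. A hedged, self-contained sketch of the two launch styles; the kernel and its arguments are placeholders, not the kernels above:

__global__ void placeholder_kernel(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

inline void launch_both_ways_example(float* data, int n, cudaStream_t stream) {
  int blocks = (n + 255) / 256;
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // Classic launch: grid, block, dynamic shared memory, stream.
  placeholder_kernel<<<blocks, 256, 0, stream>>>(data, n);
#else
  // Extended launch: the same parameters carried in a config struct,
  // which is also where launch attributes (omitted here) would be attached.
  cudaLaunchConfig_t config{};
  config.gridDim = dim3(blocks);
  config.blockDim = dim3(256);
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  config.numAttrs = 0;
  cudaLaunchKernelEx(&config, placeholder_kernel, data, n);
#endif
}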

#define INSTANTIATE_NOAUX_TC(T, IdxT) \
9 changes: 8 additions & 1 deletion custom_ops/setup_ops.py
@@ -600,9 +600,16 @@ def find_end_files(directory, end_str):
"gpu_ops/read_data_ipc.cu",
"gpu_ops/dequant_int8.cu",
"gpu_ops/share_external_data.cu",
"gpu_ops/recover_decode_task.cu",
"gpu_ops/noaux_tc.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/text_image_index_out.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/append_attn/mla_cache_kernel.cu",
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/recover_decode_task.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
2 changes: 1 addition & 1 deletion fastdeploy/config.py
@@ -1160,7 +1160,7 @@ def __init__(self, args):
self.kv_cache_ratio = 1.0
else:
self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 0 if current_platform.is_maca() else envs.FD_ENC_DEC_BLOCK_NUM
self.enc_dec_block_num = envs.FD_ENC_DEC_BLOCK_NUM
self.prealloc_dec_block_slot_num_threshold = 12
self.cache_dtype = "bfloat16"
self.model_cfg = None
3 changes: 2 additions & 1 deletion fastdeploy/engine/common_engine.py
@@ -1006,7 +1006,8 @@ def _exit_sub_services(self):
exit sub services
"""
self.running = False
self.engine_worker_queue_server.cleanup()
if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
self.engine_worker_queue_server.cleanup()
self.exist_task_signal.clear()
self.exist_swapped_task_signal.clear()
self.worker_healthy_live_signal.clear()
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/layers/backends/metax/__init__.py
@@ -13,11 +13,13 @@
# limitations under the License.

from .attention.flash_attn_backend import FlashAttentionBackend
from .attention.mla_attn_metax_backend import MetaxMLAAttentionBackend
from .moe.fused_moe_cutlass_metax_backend import MetaxCutlassWeightOnlyMoEMethod
from .moe.fused_moe_triton_metax_backend import MetaxTritonWeightOnlyMoEMethod

__all__ = [
"FlashAttentionBackend",
"MetaxMLAAttentionBackend",
"MetaxTritonWeightOnlyMoEMethod",
"MetaxCutlassWeightOnlyMoEMethod",
]