diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f1f9a781a07..72a88c633370 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -924,7 +924,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP") set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp" "csrc/rocm/skinny_gemms.cu" - "csrc/rocm/attention.cu") + "csrc/rocm/attention.cu" + "csrc/rocm/rocsolgemm.cu") define_gpu_extension_target( _rocm_C diff --git a/README.md b/README.md index f9e40acc7d5f..bc849fa3236a 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone --- # Deprecation warning + > [!CAUTION] > The ROCm/vllm repository is retired, please use the [upstream](https://github.com/vllm-project/vllm.git) repository > diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 34dcc9401aae..9d840323d61e 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,6 +11,9 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); +torch::Tensor RocSolIdxBlas(torch::Tensor& mat1, torch::Tensor& mat2, + int64_t solution_index); + void paged_attention( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, diff --git a/csrc/rocm/rocsolgemm.cu b/csrc/rocm/rocsolgemm.cu new file mode 100644 index 000000000000..5bdfea26f9d4 --- /dev/null +++ b/csrc/rocm/rocsolgemm.cu @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// #ifdef __gfx908__ +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif + +#include + +// #include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include "core/registration.h" + +#include + +#ifndef CHECK_HIP_ERROR + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_ROCBLAS_STATUS + #define CHECK_ROCBLAS_STATUS(status) \ + if (status != rocblas_status_success) { \ + fprintf(stderr, "rocBLAS error: '%s'(%d) at %s:%d\n", \ + rocblas_status_to_string(status), status, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +// namespace +// { +// rocblas_handle r_handle; +// rocb_create_extension(); +// // /*thread_local*/ cudaStream_t weight_stream; +// // BUG: DLM has event and stream on different devices error +// // In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? +// // C++ keyword: thread_local <- maybe this can help? +// // /*thread_local*/ cudaEvent_t event; + +// // // hipBLASLt +// // hipblasLtHandle_t hipblaslt_handle; +// // hipblasLtMatmulPreference_t preference; +// // uint64_t workspace_size = 32 * 1024 * 1024; +// // // uint64_t workspace_size = 0; +// // void *d_workspace; +// // int request_solutions = 1; +// // int returnedAlgoCount = 0; + +// struct MatMulConfig +// { +// hipblasOperation_t op_A; +// hipblasOperation_t op_B; +// int M; +// int N; +// int K; +// hipblasDiagType_t dtype; + +// friend auto operator<(const MatMulConfig &left, const MatMulConfig +// &right) -> bool +// { +// return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, +// left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, +// right.K, right.dtype); +// } +// }; + +// // std::map, +// std::vector> heuristic_map; +// std::map heuristic_map; + +// // hipEvent_t start, stop; +// // int bench_iters{1}; +// // int warmup_iters{1}; + +// // bool cout_print = true; +// } + +rocblas_handle r_handle; + +torch::Tensor RocSolIdxBlas(torch::Tensor& mat1, torch::Tensor& mat2, + int64_t solution_index) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto abcType{mat1.options().dtype()}; + auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << + // result.strides() << std::endl; + + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + // std::cout << " | transpose_result: " << (transpose_result ? "true" : + // "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << + // std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << + // std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl + // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " + // << result_ld << std::endl; + + // std::cout << "Input mat1: size: " << mat1.sizes() << ", stride: " << + // mat1.strides() << std::endl; std::cout << "Input mat2: size: " << + // mat2.sizes() << ", stride: " << mat2.strides() << std::endl; std::cout << + // "mat1 data: "; std::cout << mat1 << std::endl; // 打印 mat1 的数据 + // std::cout << "mat2 data: "; + // std::cout << mat2 << std::endl; // 打印 mat2 的数据 + + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + if (r_handle == nullptr) rocblas_create_handle(&r_handle); + rocblas_set_stream(r_handle, current_stream); + uint32_t flags{0}; + // int32_t solution_index {0}; + rocblas_datatype abcRtype; + if (abcType == at::kHalf) { + abcRtype = rocblas_datatype_f16_r; + } else if (abcType == at::kBFloat16) { + abcRtype = rocblas_datatype_bf16_r; + } else if (abcType == at::kFloat) { + abcRtype = rocblas_datatype_f32_r; + } else { + assert(false && "Wrong datatype!"); + } + + // CHECK_ROCBLAS_ERROR( + rocblas_status rstatus = rocblas_gemm_ex( + r_handle, + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, + transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, m, + n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, ptrC, + abcRtype, result_ld, ptrC, abcRtype, result_ld, rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, solution_index, flags); + //); + CHECK_ROCBLAS_STATUS(rstatus); + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void rocb_create_extension() { rocblas_create_handle(&r_handle); } + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void rocb_destroy_extension() { rocblas_destroy_handle(r_handle); } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("RocSolIdxBlas", &RocSolIdxBlas); +} \ No newline at end of file diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 66bdc448da3c..31aa56fe39ea 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -32,6 +32,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " Tensor scale_b, int CuCount) -> ()"); rocm_ops.impl("wvSplitKQ", torch::kCUDA, &wvSplitKQ); + rocm_ops.def( + "RocSolIdxBlas(Tensor mat1, Tensor mat2, int solution_index) -> " + "Tensor"); + rocm_ops.impl("RocSolIdxBlas", torch::kCUDA, &RocSolIdxBlas); + // Custom attention op // Compute the attention between an input query and the cached // keys/values using PagedAttention. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6e9a8df0a56a..aac63ec1b672 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1462,6 +1462,21 @@ def wvSplitKQ(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype, return out +if hasattr(torch.ops._rocm_C, "RocSolIdxBlas"): + + @register_fake("_rocm_C::RocSolIdxBlas") + def _RocSolIdxBlas_fake(a: torch.Tensor, b: torch.Tensor, + solution_index: int) -> torch.Tensor: + return torch.empty((a.size(0), b.size(1)), + dtype=a.dtype, + device=a.device) + + +def RocSolIdxBlas(a: torch.Tensor, b: torch.Tensor, + solution_index: int) -> torch.Tensor: + return torch.ops._rocm_C.RocSolIdxBlas(a, b, solution_index) + + # moe def moe_sum(input: torch.Tensor, output: torch.Tensor): torch.ops._moe_C.moe_sum(input, output) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index d2b135c1e4d4..a2f22ed8fd1b 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Utility methods for model layers.""" +import os from typing import Callable, Optional import torch @@ -10,6 +11,9 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op +VLLM_USE_ROCB_GEMM = (os.getenv("VLLM_USE_ROCB_GEMM", "False").lower() + in ("true", "1")) + def shuffle_weight(w: torch.Tensor) -> torch.Tensor: # Shuffle weight along the last dimension so that @@ -98,16 +102,23 @@ def rocm_unquantized_gemm_impl( bias: Optional[torch.Tensor] = None) -> torch.Tensor: from vllm.platforms.rocm import on_gfx9 k = weight.shape[1] + x_view = x.view(-1, x.size(-1)) + n = x_view.shape[0] + m = weight.shape[0] use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ x.dtype in [torch.float16, torch.bfloat16] \ and k % 8 == 0 and bias is None) + if VLLM_USE_ROCB_GEMM: + solidx = 0 + out = ops.RocSolIdxBlas(x_view, weight.t(), solidx) + if bias is not None: + out = out + bias + return out + if use_skinny is not True: return torch.nn.functional.linear(x, weight, bias) - x_view = x.view(-1, x.size(-1)) - n = x_view.shape[0] - m = weight.shape[0] cu_count = current_platform.get_cu_count() if m > 8 and 0 < n <= 4: