65 changes: 65 additions & 0 deletions kernels/csrc/attention/attention_generic.cuh
@@ -0,0 +1,65 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <stdint.h>

namespace vllm {

// A vector type to store Q, K, V elements.
template <typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template <typename T>
struct FloatVec {};

// Template vector operations.
template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template <typename T>
inline __device__ float sum(T v);

template <typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}

template <typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}

template <typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;

#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}

} // namespace vllm
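
Note on the interface above: Vec, FloatVec, mul, and sum are left intentionally empty or undefined here; each supported element type contributes its own specializations in a companion dtype_*.cuh header, and only dot and zero are generic. A hedged sketch of how a type plugs in (the float4 specializations below are written purely for illustration, in the spirit of vLLM's dtype_float32.cuh):

// Illustrative only: float specializations of the generic interface declared
// in attention_generic.cuh. The real ones live in a dtype_float32.cuh header.
#include "attention_generic.cuh"

namespace vllm {

template <>
struct Vec<float, 4> {
  using Type = float4;  // 4 Q/K/V elements loaded as one float4
};

template <>
struct FloatVec<float4> {
  using Type = float4;  // the FP32 accumulator for a float4 is float4 itself
};

// Element-wise product, Acc = A = B = float4.
template <>
inline __device__ float4 mul<float4, float4, float4>(float4 a, float4 b) {
  return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}

// Horizontal reduction of a float4 to a scalar.
template <>
inline __device__ float sum<float4>(float4 v) {
  return v.x + v.y + v.z + v.w;
}

// With these in place, dot<float4>(q_vec, k_vec) above resolves to
// sum(mul<float4, float4, float4>(q_vec, k_vec)), and zero(q_vec) clears all
// four 32-bit words of the register.

}  // namespace vllm
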
41 changes: 41 additions & 0 deletions kernels/csrc/attention/dtype_fp8.cuh
@@ -0,0 +1,41 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8

namespace vllm {

enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};

template <>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};

template <>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};

template <>
struct Vec<uint8_t, 8> {
using Type = uint2;
};

} // namespace vllm
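
Each specialization above packs N one-byte FP8 values into a single register-sized integer type, so quantized K/V data can be moved with one vectorized access. A minimal sketch of what the mapping guarantees; the load helper and its alignment assumption are hypothetical, not part of the header:

// Illustrative only: the packed type for N fp8 bytes is an N-byte scalar
// (or uint2 for 8 bytes), which the static_asserts below simply restate.
#include <stdint.h>
#include "dtype_fp8.cuh"

namespace vllm {

static_assert(sizeof(Vec<uint8_t, 1>::Type) == 1, "1 fp8 value  -> uint8_t");
static_assert(sizeof(Vec<uint8_t, 2>::Type) == 2, "2 fp8 values -> uint16_t");
static_assert(sizeof(Vec<uint8_t, 4>::Type) == 4, "4 fp8 values -> uint32_t");
static_assert(sizeof(Vec<uint8_t, 8>::Type) == 8, "8 fp8 values -> uint2");

// Hypothetical helper: load 8 quantized KV-cache bytes in a single access,
// assuming the pointer is 8-byte aligned.
__device__ inline Vec<uint8_t, 8>::Type load_fp8x8(const uint8_t* ptr) {
  return *reinterpret_cast<const Vec<uint8_t, 8>::Type*>(ptr);
}

}  // namespace vllm
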
49 changes: 27 additions & 22 deletions kernels/csrc/fused_kernels.cu
@@ -27,25 +27,25 @@ __global__ void dequant_add_residual_kernel(const int32_t *__restrict__ input,
for (int i = tid; i < hidden_size; i += blockDim.x) {
if constexpr (use_per_token_dequant) {
output[token_idx * hidden_size + i] =
(T)((((float)input[token_idx * hidden_size + i]) * __half2float(scale[token_idx])) +
(T)((((float)input[token_idx * hidden_size + i]) * to_float(scale[token_idx])) +
(float)residual[token_idx * hidden_size + i]);
} else {
output[token_idx * hidden_size + i] =
(T)((((float)input[token_idx * hidden_size + i]) * __half2float(scale)) +
(T)((((float)input[token_idx * hidden_size + i]) * to_float(scale)) +
(float)residual[token_idx * hidden_size + i]);
}
}
}

template <typename T>
__global__ void dequant_kernel(const int32_t *__restrict__ input,
T *__restrict__ output, const at::Half scale, int m,
T *__restrict__ output, half scale, int m,
int hidden_size, int input_stride, int out_stride) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
for (int i = tid; i < hidden_size; i += blockDim.x) {
output[token_idx * out_stride + i] =
(T)(((float)input[token_idx * input_stride + i]) * __half2float(scale));
(T)(((float)input[token_idx * input_stride + i]) * to_float(scale));
}
}
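
The switch from __half2float/__float2half_rn to to_float/from_float in the hunks above and below is what lets these kernels accept both half and bfloat16. The overloads are assumed to come from utils.cuh, which this diff does not include; a sketch of the behavior the call sites rely on:

// Sketch only: the to_float / from_float overloads the kernels rely on.
// The real definitions are assumed to live in utils.cuh.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

inline __device__ float to_float(half u) { return __half2float(u); }
inline __device__ float to_float(__nv_bfloat16 u) { return __bfloat162float(u); }
inline __device__ float to_float(float u) { return u; }

template <typename T>
inline __device__ T from_float(float u);

template <>
inline __device__ half from_float<half>(float u) {
  return __float2half_rn(u);  // round-to-nearest-even, as the old code did
}

template <>
inline __device__ __nv_bfloat16 from_float<__nv_bfloat16>(float u) {
  return __float2bfloat16_rn(u);
}
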

@@ -71,7 +71,7 @@ __global__ void quant_kernel(const T *__restrict__ input,
const float block_amax_val = blockReduceMax(amax_val);
if (tid == 0) {
s_amax = block_amax_val;
scale[token_idx] = __float2half_rn(block_amax_val / 127.0f);
scale[token_idx] = from_float<T>(block_amax_val / 127.0f);
}
__syncthreads();

@@ -83,18 +83,19 @@ __global__ void quant_kernel(const T *__restrict__ input,
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
output[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / __half2float(scale));
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / to_float(scale));
}
}
}


template <typename T, typename scale_type, bool use_per_token_quant>
__global__ void quant_kernel_fuse_sum(const T *__restrict__ input,
int8_t *__restrict__ output,
scale_type input_sum,
scale_type scale,
int num_tokens, int hidden_size) {
int8_t *__restrict__ output,
scale_type input_sum,
scale_type scale,
int num_tokens,
int hidden_size) {
// TODO: get the sum here.
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
@@ -118,8 +119,8 @@ __global__ void quant_kernel_fuse_sum(const T *__restrict__ input,
const float block_sum_val = blockReduceSum(sum_val);
if (tid == 0) {
s_amax = block_amax_val;
scale[token_idx] = __float2half_rn(block_amax_val / 127.0f);
input_sum[token_idx] = __float2half_rn(block_sum_val);
scale[token_idx] = from_float<T>(block_amax_val / 127.0f);
input_sum[token_idx] = from_float<T>(block_sum_val);
}
__syncthreads();

@@ -131,7 +132,7 @@ __global__ void quant_kernel_fuse_sum(const T *__restrict__ input,
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
output[token_idx_mul_hidden_size + i] =
float_to_int8_rn(((float)input[token_idx_mul_hidden_size + i]) / __half2float(scale));
float_to_int8_rn(((float)input[token_idx_mul_hidden_size + i]) / to_float(scale));
}
}
}
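
For reference, the per-token branch of both quant kernels is plain symmetric int8 quantization: the scale is the token's absolute maximum over 127, and every element is divided by that scale and rounded. A host-side sketch of the same arithmetic (the function name and use of std::round are illustrative; the device code uses a block-wide max reduction and float_to_int8_rn):

// Illustrative host-side version of the per-token symmetric int8 quantization
// performed by quant_kernel / quant_kernel_fuse_sum above.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> quantize_token(const std::vector<float>& x, float& scale_out) {
  float amax = 1e-6f;  // small floor so an all-zero token cannot divide by zero
  for (float v : x) amax = std::max(amax, std::fabs(v));
  const float scale = amax / 127.0f;  // what the kernel writes to scale[token_idx]
  std::vector<int8_t> q(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    q[i] = static_cast<int8_t>(std::round(x[i] / scale));  // lands in [-127, 127]
  }
  scale_out = scale;  // dequantize later as q[i] * scale
  return q;
}
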
@@ -205,8 +206,10 @@ void invoke_quant(torch::Tensor &out, // [..., hidden_size]
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
vllm::quant_kernel<scalar_t, at::Half, false><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), scale, num_tokens, hidden_size);
// It looks like this function is never called. We hard-code the template to half since the input argument scale is at::Half.
// using T = typename FloatTypeConverter<scalar_t>::Type;
vllm::quant_kernel<half, half, false><<<grid, block, 0, stream>>>(
reinterpret_cast<half*>(input.data_ptr<scalar_t>()), out.data_ptr<int8_t>(), scale, num_tokens, hidden_size);
});
}

@@ -221,9 +224,10 @@ void invoke_quant(torch::Tensor &out, // [..., hidden_size]
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
vllm::quant_kernel<scalar_t, at::Half *, true><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<at::Half>(), num_tokens, hidden_size);
using T = typename FloatTypeConverter<scalar_t>::Type;
vllm::quant_kernel<T, T *, true><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()), out.data_ptr<int8_t>(),
reinterpret_cast<T*>(scale.data_ptr<scalar_t>()), num_tokens, hidden_size);
});
}

@@ -259,8 +263,9 @@ void invoke_quant_fuse_sum(torch::Tensor &out, // [..., hidden_size]
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel_fuse_sum", [&] {
vllm::quant_kernel_fuse_sum<scalar_t, at::Half *, true><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), input_sum.data_ptr<at::Half>(),
scale.data_ptr<at::Half>(), num_tokens, hidden_size);
using T = typename FloatTypeConverter<scalar_t>::Type;
vllm::quant_kernel_fuse_sum<T, T*, true><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()), out.data_ptr<int8_t>(), reinterpret_cast<T*>(input_sum.data_ptr<scalar_t>()),
reinterpret_cast<T*>(scale.data_ptr<scalar_t>()), num_tokens, hidden_size);
});
}
}
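
The launch wrappers above, and those in layernorm_kernels.cu below, now map the dispatched Torch scalar type onto a CUDA type via FloatTypeConverter before reinterpret_cast'ing the data pointers. The trait is assumed to be defined in utils.cuh; a sketch of the mapping the casts rely on, which works because at::Half and at::BFloat16 share their bit layout with the corresponding CUDA types:

// Assumed shape of the FloatTypeConverter trait used by the launch wrappers.
// The real definition is expected to live in utils.cuh; this is only a sketch.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <torch/extension.h>

template <typename T>
struct FloatTypeConverter {
  using Type = T;  // float (and any other type) passes through unchanged
};

template <>
struct FloatTypeConverter<at::Half> {
  using Type = half;  // at::Half is 16-bit IEEE half, same layout as CUDA half
};

template <>
struct FloatTypeConverter<at::BFloat16> {
  using Type = __nv_bfloat16;  // same story for bfloat16
};
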
33 changes: 19 additions & 14 deletions kernels/csrc/layernorm_kernels.cu
@@ -10,6 +10,7 @@
#include <torch/extension.h>

#include "dispatch_utils.h"

#include "utils.cuh"
#include "reduction_utils.cuh"

@@ -134,7 +135,7 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? to_float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f;

for (int i = tidx; i < n_elems; i += blockDim.x)
@@ -270,9 +271,11 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? __half2float(*scale_orig_quant_per_tensor) : 0.0f);
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? to_float(*scale_orig_quant_per_tensor) : 0.0f);
T_scalar amax = 1e-6f;
T_scalar sum = 0.0f;
// Accumulate the sum in float: a half/bf16 accumulator loses precision once the running total grows (fp16 cannot even represent odd integers above 2048).
// T_scalar sum = 0.0f;
float sum = 0.0f;

for (int i = tidx; i < n_elems; i += blockDim.x)
{
@@ -303,7 +306,7 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const
if (with_per_token_scaling)
{
float abs_max_f = blockAllReduceMax(cuda_cast<float>(amax));
float sum_f = blockAllReduceSum(cuda_cast<float>(sum));
float sum_f = blockAllReduceSum(sum);
const float dynamic_per_token_scale = 127.f / abs_max_f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
@@ -316,11 +319,13 @@ __global__ void generalLayerNorm_fuse_sum(const T* input, const T* gamma, const

reinterpret_cast<int8_packed_t*>(normed_output_quant)[index]
= cuda_cast<int8_packed_t>(val_f * cuda_cast<float_packed_t>(dynamic_per_token_scale));
// For debug: output unquantized result.
// reinterpret_cast<T*>(normed_output_quant)[index] = cuda_cast<T>(val_f);
}
if (tidx == 0)
{
scale_orig_quant_per_token[bidx] = abs_max_f / 127.f;
input_sum[bidx] = sum_f;
input_sum[bidx] = from_float<scale_type>(sum_f);
}
}
}
@@ -441,10 +446,10 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
using T = typename FloatTypeConverter<scalar_t>::Type;
if (use_per_token_quant) {
// per-token
vllm::generalLayerNorm<T, at::Half><<<grid, block, 0, stream>>>(
vllm::generalLayerNorm<T, T><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(),
nullptr, epsilon, num_tokens, hidden_size, nullptr, reinterpret_cast<T*>(scaling.data_ptr<scalar_t>()),
out.data_ptr<int8_t>(), false
);
// input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
@@ -453,10 +458,10 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
// weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
} else {
// per-tensor
vllm::generalLayerNorm<T, at::Half><<<grid, block, 0, stream>>>(
vllm::generalLayerNorm<T, T><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, scaling.data_ptr<at::Half>(), nullptr,
nullptr, epsilon, num_tokens, hidden_size, reinterpret_cast<T*>(scaling.data_ptr<scalar_t>()), nullptr,
out.data_ptr<int8_t>(), false
);
}
@@ -481,11 +486,11 @@ void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
using T = typename FloatTypeConverter<scalar_t>::Type;
if (use_per_token_quant) {
// per-token
vllm::generalLayerNorm_fuse_sum<T, at::Half><<<grid, block, 0, stream>>>(
vllm::generalLayerNorm_fuse_sum<T, T><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, input_sum.data_ptr<at::Half>(), nullptr, scaling.data_ptr<at::Half>(),
out.data_ptr<int8_t>(), false
nullptr, epsilon, num_tokens, hidden_size, reinterpret_cast<T*>(input_sum.data_ptr<scalar_t>()),
nullptr, reinterpret_cast<T*>(scaling.data_ptr<scalar_t>()), out.data_ptr<int8_t>(), false
);
// input, gamma, beta, normed_output, eps, tokens, hidden_dim, per_tensor_scale, per_token_scale
// normed_output_quant, use_shmem
@@ -497,10 +502,10 @@ void rms_norm_general_fuse_sum(torch::Tensor &out, // [..., hidden_size]
// Not implemented per-tensor input_sum
assert(false);

vllm::generalLayerNorm_fuse_sum<T, at::Half><<<grid, block, 0, stream>>>(
vllm::generalLayerNorm_fuse_sum<T, T><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(), nullptr,
nullptr, epsilon, num_tokens, hidden_size, nullptr, reinterpret_cast<T*>(scaling.data_ptr<scalar_t>()), nullptr,
out.data_ptr<int8_t>(), false
);
}