83 changes: 80 additions & 3 deletions benchmarks/fused_kernels/layernorm_rms_benchmarks.py
@@ -14,6 +14,9 @@

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)


@dataclass
@@ -22,13 +25,15 @@ class bench_params_t:
hidden_size: int
add_residual: bool
dtype: torch.dtype
group_size: list[int]  # quantization block shape, e.g. [1, 128]

def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x R {self.add_residual} "
f"x DT {self.dtype}"
f"x GS {self.group_size}"
)


@@ -38,10 +43,11 @@ def get_bench_params() -> list[bench_params_t]:
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]
GROUP_SIZES = [[1, 64], [1, 128]]

combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations)
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
)
return bench_params

@@ -52,6 +58,7 @@ def unfused_int8_impl(
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],  # unused; keeps the signature uniform across impls
):
# Norm
torch_out = None
@@ -69,6 +76,7 @@ def unfused_fp8_impl(
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],  # unused; keeps the signature uniform across impls
):
# Norm
torch_out = None
@@ -81,23 +89,57 @@ def unfused_fp8_impl(
torch_out, _ = ops.scaled_fp8_quant(torch_out)


def unfused_groupwise_fp8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
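# Each group of group_size[1] consecutive channels within a token gets its own FP8 scale.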
torch_out, _ = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)


def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
)


def fused_groupwise_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
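# Fused path: residual add + RMS norm + groupwise quant in a single custom op.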
out, _ = ops.rms_norm_per_block_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, group_size, residual=residual
)


# Bench functions
def bench_fn(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
group_size: list[int],
label: str,
sub_label: str,
fn: Callable,
@@ -110,10 +152,11 @@ def bench_fn(
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
@@ -147,6 +190,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
unfused_int8_impl,
@@ -161,6 +205,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
@@ -175,6 +220,7 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
fused_impl,
@@ -189,13 +235,44 @@ def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasu
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_fp8_impl",
)
)

# unfused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)

# fused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_groupwise_impl,
"fused_groupwise_fp8_impl",
)
)

print_timers(timers)

return timers
7 changes: 7 additions & 0 deletions csrc/ops.h
@@ -123,6 +123,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);

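// Fused residual-add + RMS norm + groupwise (per-block) quantization.
// `scales` has shape [num_tokens, hidden_size / group_size]; each group of
// `group_size` consecutive channels within a token shares one scale
// (e.g. hidden_size 1024, group_size 128 -> 8 scales per token).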
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales, double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual,
int64_t group_size);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
(The remaining hunks modify the CUDA source implementing the fused RMS norm + quantization kernels.)
@@ -31,14 +31,15 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(

// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale;
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
}
}

@@ -75,14 +76,53 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(

// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
token_scale = 1.0f / token_scale;
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
out, input, weight, rms, &token_scale, hidden_size, residual);
}
}

// RMS norm + per-block (groupwise) quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_per_block_quant_kernel(
float* __restrict__ rms,
scalar_out_t* __restrict__ out, // [..., hidden_size]
float* __restrict__ scales, // [num_tokens, hidden_size / group_size]
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float* __restrict__ token_scale,  // [num_tokens * hidden_size / group_size]
float const* scale_ub,
float const var_epsilon, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr, int32_t const group_size = 0) {
// Compute RMS
// Always able to vectorize due to constraints on hidden_size
vllm::vectorized::compute_rms<scalar_t, has_residual>(
rms + blockIdx.x, input, hidden_size, var_epsilon, residual);

// Compute Scale
// Always able to vectorize due to constraints on hidden_size and group_size
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
has_residual>(
token_scale, scales, input, weight, rms[blockIdx.x], scale_ub,
hidden_size, residual, group_size);

// RMS Norm + Quant
// Always able to vectorize due to constraints on hidden_size
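// token_idx is the offset of this token's first group scale within token_scale.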
int token_idx = blockIdx.x * hidden_size / group_size;
// For int8, don't invert token_scale here; do it inside the norm_and_quant
// kernel instead. Elements of token_scale can be shared between multiple
// threads, so deferring the inversion avoids extra synchronization overhead.
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t,
std::is_same_v<scalar_out_t, int8_t>,
has_residual>(
out, input, weight, rms[blockIdx.x], token_scale + token_idx, hidden_size,
residual, group_size);
}

} // namespace vllm

// Residual add + RMS norm + dynamic per token
@@ -157,3 +197,84 @@ void rms_norm_dynamic_per_token_quant(
out, input, weight, scales, var_epsilon, scale_ub, residual);
});
}

// Residual add + RMS norm + per-block (groupwise) quant
// TODO think up better names than kernel_1, kernel_2, kernel_3, cleanup args
// TODO vectorized kernels
template <typename scalar_in_t>
void rms_norm_per_block_quant_dispatch(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens, hidden_size / group_size]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual, int64_t group_size) {
int32_t hidden_size = input.size(-1);
auto num_tokens = input.numel() / hidden_size;

dim3 grid13(num_tokens);
dim3 block13(std::min(hidden_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

auto const fp_options =
torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
torch::Tensor rms = torch::empty({num_tokens}, fp_options);
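// Workspace holding one scale per (token, group), written by
// compute_dynamic_per_token_scales and consumed by norm_and_quant.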
torch::Tensor token_scale =
torch::empty({num_tokens * hidden_size / group_size}, fp_options);

if (residual.has_value()) {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t, true>
<<<grid13, block13, 0, stream>>>(
rms.data_ptr<float>(), out.data_ptr<scalar_t>(),
scales.data_ptr<float>(), input.data_ptr<scalar_in_t>(),
weight.data_ptr<scalar_in_t>(), token_scale.data_ptr<float>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>(),
group_size);
});
} else {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_per_block_quant_kernel", [&] {
vllm::rms_norm_per_block_quant_kernel<scalar_in_t, scalar_t, false>
<<<grid13, block13, 0, stream>>>(
rms.data_ptr<float>(), out.data_ptr<scalar_t>(),
scales.data_ptr<float>(), input.data_ptr<scalar_in_t>(),
weight.data_ptr<scalar_in_t>(), token_scale.data_ptr<float>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, hidden_size, nullptr, group_size);
});
}
}

void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales, double const var_epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual,
int64_t group_size) {
static c10::ScalarType kFp8Type = is_fp8_ocp()
? c10::ScalarType::Float8_e4m3fn
: c10::ScalarType::Float8_e4m3fnuz;
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
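// Note: hidden_size is assumed to be divisible by group_size (see the
// kernel's vectorization constraints).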

if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_per_block_quant_dispatch", [&] {
rms_norm_per_block_quant_dispatch<scalar_t>(out, input, weight, scales,
var_epsilon, scale_ub,
residual, group_size);
});
}