
Commit b77f157

Dispatch MXFP4 weight conversion for sm70 & sm75 (#3937)
* simplify weight conversion dispatch
* fix sm70 window attention
1 parent 186606c commit b77f157

File tree

9 files changed (+198, -250 lines)


src/turbomind/kernels/attention/mainloop_sm70.h

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,8 @@ struct Mainloop<arch::Sm70, Impl_> {
         Impl::ComputePV(state_PV, frag_O, 0, nop, [&] {});
 
         gmem_K.Save(tmp_K);
+
+        offset_K -= CTA_S;
     };
 
     for (int mask_iter = max(1, mask_iter_back); tile_iter > 0 && mask_iter > 0; --tile_iter, --mask_iter) {
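The added offset_K -= CTA_S; is the sm70 window-attention fix from this commit: the mainloop walks K/V tiles from the back of the window toward the front, so the global-memory offset of the K tile has to step back by one tile per iteration. Below is a minimal, self-contained sketch of that iteration pattern; the CTA_S and tile_count values are made up for illustration and are not the real Mainloop members.

#include <cstdio>

// Illustrative sketch only (not the TurboMind kernel): models the backward
// tile walk of the sm70 attention mainloop. offset_K marks the start of the
// current K/V tile; without the per-iteration decrement the loop would keep
// saving and reloading the same tile.
int main() {
    const int CTA_S      = 64;                // tile extent along the sequence dim (assumed)
    const int tile_count = 4;                 // tiles covered by the attention window (assumed)
    int offset_K = (tile_count - 1) * CTA_S;  // start at the last tile

    for (int tile_iter = tile_count; tile_iter > 0; --tile_iter) {
        std::printf("processing K tile at offset %d\n", offset_K);
        offset_K -= CTA_S;  // the fix: advance to the previous tile
    }
    return 0;
}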

src/turbomind/kernels/core/array_ops.h

Lines changed: 3 additions & 0 deletions
@@ -186,6 +186,9 @@ inline __device__ void Store(T* dst, const Array<T, N>& src)
     else if constexpr (sizeof(Array<T, N>) == sizeof(ushort)) {
         *(ushort*)dst = (const ushort&)src;
     }
+    else if constexpr (sizeof(Array<T, N>) == sizeof(char)) {
+        *(char*)dst = (const char&)src;
+    }
     else if constexpr (sizeof(Array<T, N>) % sizeof(uint4) == 0) {  // uncoalesced
         static_assert(bitsof<T> % 8 == 0, "raw pointer arithmetic of sub-byte types");
         constexpr int M = sizeof(Array<T, N>) / sizeof(uint4);
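The new branch gives Store a 1-byte path, which the MXFP4 layouts need because a fragment can now be as small as a single byte (for example two packed e2m1 values or one uint8 scale). A hedged, host-side sketch of the same size-dispatch idea follows, with simplified types and plain memcpy instead of the CUDA vector stores used in array_ops.h.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Simplified stand-ins for the device-side Array and Store in array_ops.h.
template<class T, int N>
struct Array {
    T data[N];
};

// Dispatch the store width on the byte size of the array; the 1-byte branch
// mirrors the sizeof(char) case added in this commit.
template<class T, int N>
void Store(void* dst, const Array<T, N>& src) {
    if constexpr (sizeof(Array<T, N>) == sizeof(uint32_t)) {
        std::memcpy(dst, &src, sizeof(uint32_t));      // 4-byte store
    } else if constexpr (sizeof(Array<T, N>) == sizeof(uint16_t)) {
        std::memcpy(dst, &src, sizeof(uint16_t));      // 2-byte store
    } else if constexpr (sizeof(Array<T, N>) == sizeof(uint8_t)) {
        std::memcpy(dst, &src, sizeof(uint8_t));       // new 1-byte store path
    } else {
        std::memcpy(dst, &src, sizeof(Array<T, N>));   // generic fallback
    }
}

int main() {
    Array<uint8_t, 1> scale{{0x7f}};  // e.g. a single per-block scale byte
    uint8_t out = 0;
    Store(&out, scale);
    std::printf("stored byte: 0x%02x\n", out);
    return 0;
}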

src/turbomind/kernels/gemm/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -20,8 +20,6 @@ add_library(gemm2
     kernel/f16_u4g128_f16_tnt_sm80_s16816.cu
     kernel/f16_u4g128_f16_tnt_sm75_s16816.cu
     kernel/f16_u4g128_f16_tnt_sm70_s884.cu
-    # kernel/f16_u4g128_f16_tnt_sm75_simt.cu
-    # kernel/u4g128_f16_f16_nnn_sm80_s16816.cu
     kernel/sm90_mxfp4.cu
     kernel/sm80_mxfp4.cu
     kernel/sm70_s884_dynamic.cu

src/turbomind/kernels/gemm/convert_v2.cu

Lines changed: 2 additions & 61 deletions
@@ -1,5 +1,6 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
+#include "src/turbomind/core/data_type.h"
 #include "src/turbomind/kernels/attention/quantization.h"
 #include "src/turbomind/kernels/core/common.h"
 #include "src/turbomind/kernels/core/math.h"
@@ -124,7 +125,7 @@ int Convert(const void* S, //
     static constexpr bool kIsValid = kPackSize % unit_size(type_c<Dtype>) == 0;
     constexpr Pack pack = mma | operand | pack_num;
 
-    if constexpr (kIsValid || operand == OPERAND_U) {
+    if constexpr (kIsValid || is_UV(operand)) {
         // Launch conversion kernel
         Convert_v2_Impl<Config<Operand, Dtype, pack_num_tag>>(S, Sdesc, D, Ddesc, stream);
         // Set leading dimension for destination
@@ -226,66 +227,6 @@ int Convert(const void* S, //
     return dispatch() - 1;
 }
 
-std::tuple<Order, Pack, Order, Pack>
-get_weight_and_scales_layout(DataType dtype, bool is_fused_moe, int sm, bool force_simt)
-{
-    if (is_fused_moe) {
-        if (dtype == kBfloat16 && sm >= 80) {
-            return {kColMajor, HMMA_16816 | OPERAND_B | 1, {}, {}};
-        }
-
-        if (dtype == kFloat16) {
-            if (sm >= 80) {
-                return {kColMajor, HMMA_16816 | OPERAND_B | 1, {}, {}};
-            }
-            else if (sm == 75) {
-                return {kColMajor, HMMA_16816 | OPERAND_B | 1, {}, {}};
-            }
-            else if (sm == 70) {
-                return {kColMajor, HMMA_884 | OPERAND_B | 1, {}, {}};
-            }
-        }
-        else if (dtype == kUint4) {
-            if (sm >= 80) {
-                return {kColMajor, HMMA_16816 | OPERAND_B | 2, kRowMajor, HMMA_16816 | OPERAND_V | 1};
-            }
-            else if (sm == 75) {
-                return {kColMajor, HMMA_16816 | OPERAND_B | 2, kRowMajor, HMMA_16816 | OPERAND_V | 1};
-            }
-            else if (sm == 70) {
-                return {kColMajor, HMMA_884 | OPERAND_B | 1, kRowMajor, HMMA_884 | OPERAND_V | 1};
-            }
-        }
-        else if (dtype == kFloat4_e2m1) {
-            if (sm >= 80) {
-                return {kColMajor, HMMA_16816 | OPERAND_A | 1, kColMajor, HMMA_16816 | OPERAND_U | 1};
-            }
-        }
-    }
-    else {
-        if (dtype == kUint4) {
-            if (force_simt) {
-                return {kColMajor, HMMA_SIMT | OPERAND_B | 1, kRowMajor, HMMA_SIMT | OPERAND_V | 1};
-            }
-            if (sm >= 80) {
-                return {kRowMajor, HMMA_16816 | OPERAND_B | 2, kRowMajor, HMMA_16816 | OPERAND_V | 1};
-            }
-            else if (sm == 75) {
-                return {kRowMajor, HMMA_16816 | OPERAND_B | 2, kRowMajor, HMMA_16816 | OPERAND_V | 1};
-            }
-            else if (sm == 70) {
-                return {kColMajor, HMMA_884 | OPERAND_B | 1, kRowMajor, HMMA_884 | OPERAND_V | 1};
-            }
-        }
-    }
-
-    std::cerr << "not implemented: dtype=" << to_string(dtype) << ", is_fused_moe=" << is_fused_moe << ", sm=" << sm
-              << std::endl;
-    std::abort();
-
-    return {};
-}
-
 namespace {
 
 template<int N>
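Two things change here: the large get_weight_and_scales_layout switch is dropped from this translation unit (the commit message says the weight-conversion dispatch was simplified), and the launch guard now accepts both scale operands via is_UV(operand) instead of only OPERAND_U. Judging by the layouts above, U appears to pair with A-side scales and V with B-side scales. The real is_UV is defined elsewhere in the gemm headers and is not shown in this diff; the sketch below is only a hypothetical illustration of the intended predicate, with made-up enum values.

#include <cstdio>

// Hypothetical stand-ins: the actual OPERAND_* tags live in the gemm headers
// and their values may differ. The point is only that both scale operands
// (U and V) are treated alike by the dispatch condition.
enum Operand { OPERAND_A, OPERAND_B, OPERAND_U, OPERAND_V };

constexpr bool is_UV(Operand op) {
    return op == OPERAND_U || op == OPERAND_V;
}

int main() {
    std::printf("B: %d, U: %d, V: %d\n", is_UV(OPERAND_B), is_UV(OPERAND_U), is_UV(OPERAND_V));
    return 0;
}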

src/turbomind/kernels/gemm/kernel/sm70_s884_dynamic.cu

Lines changed: 5 additions & 5 deletions
@@ -91,11 +91,11 @@ void Registry::sm70_s884_dynamic()
                          0>;
 
     // clang-format off
-    Add<C::Type<128, 128, 16, 2, 2, 1, D, D, 2, true, 1, 128, 64, 128>>();
-    Add<C::Type< 64, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128, 32, 128>>();
-    Add<C::Type< 32, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();
-    Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 128>>();
-    Add<C::Type<  8, 128, 64, 1, 4, 1, D, S, 2, true, 1, 128>>();
+    Add<C::Type<128, 128, 16, 2, 2, 1, D, D, 2, true, 1, 32, 64, 128>>();
+    Add<C::Type< 64, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32, 32, 128>>();
+    Add<C::Type< 32, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
+    Add<C::Type< 16, 128, 32, 1, 4, 1, D, S, 2, true, 1, 32>>();
+    Add<C::Type<  8, 128, 64, 1, 4, 1, D, S, 2, true, 1, 32>>();
     // clang-format on
 }
 }
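In these tile configurations the changed argument (128 to 32) matches MXFP4's scale granularity: MX formats share one scale per 32-element block, whereas the u4g128 kernels use 128-element groups. Assuming that template argument is the quantization group size along the reduction dimension, the bookkeeping difference is just how many scale entries each column carries, as in this small illustration.

#include <cstdio>

// Illustration only: assumes the changed template argument is the
// quantization group size along the reduction dimension k.
int main() {
    const int k            = 4096;  // example reduction dimension
    const int group_u4g128 = 128;   // AWQ/GPTQ-style group size
    const int group_mxfp4  = 32;    // MXFP4 block size
    std::printf("scales per column: u4g128 = %d, mxfp4 = %d\n",
                k / group_u4g128, k / group_mxfp4);
    return 0;
}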

src/turbomind/kernels/gemm/registry.cu

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@ Registry::Registry(std::shared_ptr<cudaDeviceProp> device_prop):
     device_prop_{std::move(device_prop)}, arch_{device_prop_->major * 100 + device_prop_->minor * 10}
 {
     f16_u4g128_f16_tnt_sm70_s884();
-    // f16_u4g128_f16_tnt_sm75_simt();
     f16_u4g128_f16_tnt_sm75_s16816();
     f16_u4g128_f16_tnt_sm80_s16816();
     f16_u4g128_f16_tnt_sm90_s16816();

src/turbomind/kernels/gemm/test/reference.cu

Lines changed: 5 additions & 4 deletions
@@ -71,15 +71,16 @@ void Reference::gemm(const void* A, MatrixLayout Adesc, const void* B, MatrixLay
         // (n, k) (k, m)
     }
 
-    CHECK(Adesc.cols == Bdesc.rows);
+    TM_CHECK_EQ(Adesc.cols, Bdesc.rows);
 
     // (m, k) (k, n)
     int m = Cdesc.rows;
     int n = Cdesc.cols;
    int k = Adesc.cols;
-    CHECK(Adesc.rows == m);
-    CHECK(Bdesc.cols == n);
-    CHECK(Bdesc.rows == k);
+
+    TM_CHECK_EQ(Adesc.rows, m);
+    TM_CHECK_EQ(Bdesc.cols, n);
+    TM_CHECK_EQ(Bdesc.rows, k);
 
     float alpha = 1.f;
     float beta = 0.f;
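Replacing CHECK(a == b) with TM_CHECK_EQ(a, b) lets a failed check report both operand values instead of a bare boolean. The sketch below shows the general idea of an EQ-style check macro; it is not the TurboMind implementation of TM_CHECK_EQ.

#include <cstdio>
#include <cstdlib>

// Generic illustration of an assert-equal macro that prints both values on
// failure; TurboMind's TM_CHECK_EQ has its own implementation and formatting.
#define SKETCH_CHECK_EQ(a, b)                                                  \
    do {                                                                       \
        auto va = (a);                                                         \
        auto vb = (b);                                                         \
        if (!(va == vb)) {                                                     \
            std::fprintf(stderr, "check failed: %s == %s (%lld vs %lld)\n",    \
                         #a, #b, (long long)va, (long long)vb);                \
            std::abort();                                                      \
        }                                                                      \
    } while (0)

int main() {
    int rows = 4, cols = 4;
    SKETCH_CHECK_EQ(rows, cols);  // passes; change cols to see the message
    return 0;
}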

src/turbomind/kernels/gemm/test/test_gemm_v2.cc

Lines changed: 7 additions & 6 deletions
@@ -24,10 +24,11 @@ int main()
 
     core::ContextGuard ctx{stream, core::Allocator{kCPU}, core::Allocator{stream, false}};
     // TestParameter p{kBfloat16, kBfloat16, kBfloat16};
+    // TestParameter p{kHalf, kHalf, kHalf};
     // TestParameter p{kBfloat16, kFloat8_e4m3, kFloat8_e4m3, 128};
-    TestParameter p{kHalf, kUint4, kHalf, 128};
+    // TestParameter p{kHalf, kUint4, kHalf, 128};
     // TestParameter p{kBfloat16, kFloat4_e2m1, kBfloat16, 32};
-    // TestParameter p{kHalf, kFloat4_e2m1, kHalf, 32};
+    TestParameter p{kHalf, kFloat4_e2m1, kHalf, 32};
 
     // p.input_dim = 512;
     // p.output_dim = 1024;
@@ -61,10 +62,10 @@ int main()
     // p.experts_per_token = 8;
 
     p.input_dim = 4096;
-    p.output_dim = 4096;
-    p.max_batch_size = 8;
-    p.expert_num = 8;
-    p.experts_per_token = 8;
+    p.output_dim = 6144;
+    p.max_batch_size = 512;
+    p.expert_num = 32;
+    p.experts_per_token = 4;
 
     // p.input_dim = 32;
     // p.output_dim = 32;

0 commit comments
