
Commit b299fe7

Simplify GEMM interface (#3818)

* simplify gemm interface
* fix lint
* fix lint

1 parent 4269dfd

26 files changed: +585 -469 lines

src/turbomind/kernels/gemm/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ add_library(gemm2
     kernel/sm80_s16816_dynamic.cu
     kernel/sm90_s16816_dynamic.cu
     kernel/sm90_q64n32.cu
+    cublas.cu
     moe_utils_v2.cu
     test/test_utils.cu
 )

src/turbomind/kernels/gemm/arch.h

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ struct Sm90: Arch<900> {
 inline bool is_arch_compatible(int karch, int darch)
 {
     switch (karch) {
+        case 0:
+            return true;
         case 700:
            return Sm70::is_compatible(darch);
        case 750:
src/turbomind/kernels/gemm/context.cu

Lines changed: 4 additions & 0 deletions
@@ -136,6 +136,10 @@ std::vector<Kernel*> StaticGemmContext::Filter(const std::vector<Kernel*>& kerne

 std::vector<LaunchSpec> StaticGemmContext::Populate(const Kernel& kernel, const PopulateParam& param) const
 {
+    if (kernel.desc().backend) {
+        return {LaunchSpec{const_cast<Kernel*>(&kernel), 0, 1}};
+    }
+
     const int m = desc_->m, n = desc_->n, k = desc_->k;

     const auto& desc = kernel.desc();
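
For a backend kernel the tuner thus sees exactly one launch candidate: the kernel itself, with swizzle 0 and a single split, since cuBLAS does its own internal scheduling. A sketch of the returned spec (assuming LaunchSpec's leading fields are kernel, swizzle, and splits, as used by the dispatch cache below):

    // The lone candidate produced for a backend (cuBLAS) kernel:
    // no tile swizzling and no split-k, since cuBLAS picks its own algorithm.
    LaunchSpec spec{const_cast<Kernel*>(&kernel), /*swizzle=*/0, /*splits=*/1};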
src/turbomind/kernels/gemm/cublas.cu

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
+#include <cublas_v2.h>
+
+#include "src/turbomind/core/cuda_data_type.h"
+#include "src/turbomind/core/data_type.h"
+
+#include "src/turbomind/kernels/gemm/kernel.h"
+#include "src/turbomind/kernels/gemm/registry.h"
+#include "src/turbomind/kernels/gemm/types.h"
+
+namespace turbomind::gemm {
+
+class CublasKernel: public Kernel {
+public:
+    explicit CublasKernel()
+    {
+        cublasCreate(&cublas_);
+        if (0) {
+            cublasSetMathMode(cublas_, CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+        }
+        desc_ = {};
+        desc_.backend = 1;
+        name_ = GetName();
+    }
+
+    int Launch(const Operation& operation,
+               float alpha,
+               const void* A,
+               const MatrixLayout& Adesc,
+               const void* U,
+               const MatrixLayout& Udesc,
+               const void* B,
+               const MatrixLayout& Bdesc,
+               const void* V,
+               const MatrixLayout& Vdesc,
+               float beta,
+               const void* C,
+               const MatrixLayout& Cdesc,
+               void* D,
+               const MatrixLayout& Ddesc,
+               int swizzle,
+               int splits,
+               Workspace& workspace,
+               cudaStream_t stream) override
+    {
+        cublasOperation_t transa = Adesc.order == kColMajor ? CUBLAS_OP_N : CUBLAS_OP_T;
+        cublasOperation_t transb = Bdesc.order == kColMajor ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+        const int m = Adesc.rows;
+        const int n = Bdesc.cols;
+        const int k = Adesc.cols;
+
+        TM_CHECK_EQ(Bdesc.rows, k);
+        TM_CHECK_EQ(Ddesc.rows, m);
+        TM_CHECK_EQ(Ddesc.cols, n);
+
+        TM_CHECK(C == nullptr || C == D);
+
+        if (stream_ != stream) {
+            cublasSetStream(cublas_, stream);
+            stream_ = stream;
+        }
+
+        if (workspace_ != workspace.partials || workspace_size_ != workspace.partials_size) {
+            cublasSetWorkspace(cublas_, workspace.partials, workspace.partials_size);
+            workspace_ = workspace.partials;
+            workspace_size_ = workspace.partials_size;
+        }
+
+        auto ec = cublasGemmEx(cublas_,
+                               transa,
+                               transb,
+                               m,
+                               n,
+                               k,
+                               &alpha,
+                               A,
+                               to_cuda_dtype(Adesc.type),
+                               Adesc.ld,
+                               B,
+                               to_cuda_dtype(Bdesc.type),
+                               Bdesc.ld,
+                               &beta,
+                               D,
+                               to_cuda_dtype(Ddesc.type),
+                               Ddesc.ld,
+                               CUDA_R_32F,
+                               CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+
+        return ec == CUBLAS_STATUS_SUCCESS ? 0 : 1;
+    }
+
+    bool is_feasible(const GemmDesc& desc) const noexcept override
+    {
+        constexpr std::tuple flat3{Striding::kFlat, Striding::kFlat, Striding::kFlat};
+
+        if (std::tie(desc.striding_a, desc.striding_b, desc.striding_c) != flat3) {
+            return false;
+        }
+        if (std::tie(desc.pack_a, desc.pack_b, desc.pack_u, desc.pack_v) != std::tuple{0, 0, 0, 0}) {
+            return false;
+        }
+        if (desc.epilogue != Epilogue::kNone) {
+            return false;
+        }
+        if (desc.num > 1) {
+            return false;
+        }
+        if (desc.quant_a || desc.quant_b) {
+            return false;
+        }
+        if (desc.sched) {
+            return false;
+        }
+        if (desc.order_c != kColMajor) {
+            return false;
+        }
+        if (desc.type_a != kHalf && desc.type_a != kBfloat16 && desc.type_a != kFloat) {
+            return false;
+        }
+        if (desc.type_b != desc.type_a) {
+            return false;
+        }
+        if (desc.type_c != desc.type_a && desc.type_c != kFloat) {
+            return false;
+        }
+        return true;
+    }
+
+    int GetMaxSplits(const int4&, int64_t, size_t, size_t) const override
+    {
+        return 1;
+    }
+
+    int GetSwizzle(int m, int n, int k, int splits, int swizzle) const override
+    {
+        return 0;
+    }
+
+private:
+    cublasHandle_t cublas_{};
+    cudaStream_t stream_{};
+    void* workspace_{};
+    size_t workspace_size_{};
+};
+
+void Registry::cublas_float()
+{
+    Add(std::make_unique<CublasKernel>());
+}
+
+}  // namespace turbomind::gemm
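
The wrapper leans on cuBLAS's column-major convention: an operand stored row-major is simply passed as the transpose of its column-major view, which is what the transa/transb selection above does. A standalone sketch of the same call pattern (real cuBLAS API; the fp16 types and the helper name gemm_f16 are illustrative):

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // D (m x n, col-major) = A (m x k) * B (k x n); fp16 inputs, fp32 accumulation.
    void gemm_f16(cublasHandle_t h, const __half* A, const __half* B, __half* D,
                  int m, int n, int k)
    {
        const float alpha = 1.f, beta = 0.f;  // scaling factors, fp32 like the compute type
        cublasGemmEx(h, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                     &alpha,
                     A, CUDA_R_16F, m,    // lda = rows of A (col-major, no transpose)
                     B, CUDA_R_16F, k,    // ldb = rows of B
                     &beta,
                     D, CUDA_R_16F, m,    // ldc = rows of D
                     CUBLAS_COMPUTE_32F,  // fp32 accumulation, as CUDA_R_32F selects above
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    }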

src/turbomind/kernels/gemm/desc.h

Lines changed: 71 additions & 1 deletion
@@ -2,9 +2,11 @@

 #pragma once

+#include <array>
+#include <tuple>
+
 #include "src/turbomind/kernels/core/data_type.h"
 #include "src/turbomind/kernels/gemm/types.h"
-#include <array>

 namespace turbomind::gemm {

@@ -47,9 +49,36 @@ inline GemmDesc transpose(GemmDesc d)
     std::swap(d.pack_u, d.pack_v);
     std::swap(d.quant_a, d.quant_b);
     std::swap(d.m, d.n);
+    d.batch_dim = 1 - d.batch_dim;
     return d;
 }

+inline std::string to_string(const GemmDesc& d)
+{
+    std::stringstream ss;
+    ss << "sm" << d.arch / 10;
+    ss << "_" << to_string(d.type_a);  //
+    if (d.quant_a) {
+        ss << to_string(d.quant_a);
+    }
+    ss << "_" << to_string(d.type_b);  //
+    if (d.quant_b) {
+        ss << to_string(d.quant_b);
+    }
+    ss << "_" << to_string(d.type_c);
+    ss << "_"                                   //
+       << (d.order_a == kColMajor ? 'n' : 't')  //
+       << (d.order_b == kColMajor ? 'n' : 't')  //
+       << (d.order_c == kColMajor ? 'n' : 't');  //
+    ss << "_"                      //
+       << to_string(d.striding_a)  //
+       << to_string(d.striding_b)  //
+       << to_string(d.striding_c);
+    ss << "_" << d.m << "x" << d.n << "x" << d.k;
+    ss << "_" << d.num;
+    return ss.str();
+}
+
 enum class OpClass
 {
     kSIMT,

@@ -101,12 +130,53 @@ struct KernelDesc {
     int stages;
     bool split_k;
     int sched;
+    int backend;
+    bool transpose;

     // set by `KernelImpl`
     int max_active_ctas;
     cudaFuncAttributes attr;
 };

+inline KernelDesc transpose(const KernelDesc& d)
+{
+    KernelDesc k{d};
+
+    k.arch = d.arch;
+    k.op_class = d.op_class;
+
+    k.order_a = ~d.order_b;
+    k.order_b = ~d.order_a;
+    k.order_c = ~d.order_c;
+
+    k.type_a = d.type_b;
+    k.type_b = d.type_a;
+
+    k.striding_a = d.striding_b;
+    k.striding_b = d.striding_a;
+
+    k.pack_a = d.pack_b;
+    k.pack_b = d.pack_a;
+    k.pack_u = d.pack_v;
+    k.pack_v = d.pack_u;
+
+    k.quant_a = d.quant_b;
+    k.quant_b = d.quant_a;
+
+    k.policy_a = d.policy_b;
+    k.policy_b = d.policy_a;
+
+    auto swap = [](auto& v) { std::swap(v.x, v.y); };
+
+    swap(k.cta_tile);
+    swap(k.mma_tile);
+    swap(k.cluster_shape);
+    swap(k.align);
+    swap(k.c_tile);
+
+    return k;
+}
+
 class Kernel;
 struct LaunchSpec {
     Kernel* kernel;
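
Both transpose helpers are involutions: applying them twice restores the original descriptor, since ~~order == order, the field swaps undo themselves, and batch_dim flips back (1 - (1 - x) == x). That symmetry is what lets the library serve a problem and its transpose with a single kernel description. An illustrative check (make_desc is a hypothetical helper, not part of the commit):

    // transpose(transpose(d)) must be identical to d.
    GemmDesc d = make_desc();  // hypothetical: any fully populated descriptor
    assert(to_string(transpose(transpose(d))) == to_string(d));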

src/turbomind/kernels/gemm/dispatch_cache.cu

Lines changed: 8 additions & 2 deletions
@@ -52,6 +52,8 @@ static inline decltype(auto) as_tuple(const KernelDesc& d)
         d.c_tile,
         d.stages,
         d.split_k,
+        d.backend,
+        d.transpose,
         d.sched);
 }

@@ -139,7 +141,9 @@ void ExportDispatchCache(std::ostream& os, const std::vector<std::pair<GemmDesc,
                     k.policy_b,
                     k.c_tile.x,
                     k.c_tile.y,
-                    k.split_k);
+                    k.split_k,
+                    k.backend,
+                    k.transpose);
     // Runtime params
     export_impl(os, spec.swizzle, spec.splits);
     os << std::endl;

@@ -217,7 +221,9 @@ void ImportDispatchCache(std::istream& is,
                     k.policy_b,
                     k.c_tile.x,
                     k.c_tile.y,
-                    k.split_k);
+                    k.split_k,
+                    k.backend,
+                    k.transpose);
         LaunchSpec spec{};
         import_impl(ss, spec.swizzle, spec.splits);
         for (const auto& p : kernels) {
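
The pattern to note: each new KernelDesc field must be threaded through three places, as_tuple (which presumably drives comparison of cached entries), the export path, and the import path, or cached records would silently collide or fail to round-trip. A sketch of the invariant (illustrative, assuming tuple-based equality):

    // Two kernel descriptors name the same cache entry iff their tuples match,
    // so omitting backend/transpose here would conflate distinct kernels.
    bool same_entry = as_tuple(lhs) == as_tuple(rhs);  // lhs, rhs: KernelDesc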

src/turbomind/kernels/gemm/gemm.cu

Lines changed: 2 additions & 35 deletions
@@ -277,6 +277,7 @@ int Gemm::Run(const Operation& operation,

     if (!desc) {
         fprintf(stderr, "invalid argument.\n");
+        TM_CHECK(0);
         return -1;
     }

@@ -330,41 +331,7 @@
         return launch(spec, stream);
     }

-    const auto launch1 = [=](LaunchSpec spec, cudaStream_t st) {
-        auto _workspace = workspace;
-        return spec.kernel->Launch(operation,
-                                   alpha,
-                                   B,
-                                   transpose(Bdesc),
-                                   V,
-                                   transpose(Vdesc),
-                                   A,
-                                   transpose(Adesc),
-                                   U,
-                                   transpose(Udesc),
-                                   beta,
-                                   C,
-                                   transpose(Cdesc),
-                                   D,
-                                   transpose(Ddesc),
-                                   spec.swizzle,
-                                   spec.splits,
-                                   _workspace,
-                                   stream);
-    };
-
-    if (operation.dispatch & DispatchPolicy::kMeasure) {
-        impl_->Measure(context, transpose(*desc), workspace.barriers_size, workspace.partials_size, 1, launch1, stream);
-    }
-
-    spec = impl_->Dispatch(
-        context, operation.dispatch, transpose(*desc), workspace.barriers_size, workspace.partials_size);
-
-    if (spec.kernel) {
-        return launch1(spec, stream);
-    }
-
-    fprintf(stderr, "No feasible kernel found for the problem.\n");
+    TM_CHECK(0) << "No feasible kernel found for the problem: " << to_string(*desc);

     return -1;
 }
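
The deleted launch1 path implemented the textbook rewrite D = A*B as transpose(D) = transpose(B)*transpose(A) at launch time, swapping the roles of A and B and transposing every layout descriptor before calling the kernel. The transpose(KernelDesc) helper and the new transpose flag added in desc.h suggest that rewrite is now represented in the kernel registry itself, so Run no longer needs a second Measure/Dispatch round, and an unsatisfiable problem now aborts through TM_CHECK with the full to_string signature instead of printing to stderr and returning -1. A condensed view of what the lambda used to do (argument names from the deleted code):

    // Transposed fallback (now retired): swap operand roles, flip layouts.
    //   kernel->Launch(operation, alpha,
    //                  B, transpose(Bdesc), V, transpose(Vdesc),
    //                  A, transpose(Adesc), U, transpose(Udesc),
    //                  beta, C, transpose(Cdesc), D, transpose(Ddesc), ...);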
