llmc/gelu.cuh (59 additions, 0 deletions)
@@ -44,6 +44,47 @@ __global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp
store128(d_in_out + idx, packed_dinp);
}

__global__ void swiglu_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;

x128 packed_out;
x128 packed_inp1 = load128cs(inp1 + idx); // load and do not keep in cache
x128 packed_inp2 = load128cs(inp2 + idx);
for(int k = 0; k < packed_inp1.size; ++k) {
float x1 = (float)packed_inp1[k];
float x2 = (float)packed_inp2[k];
// swish(x1) = x1 * sigmoid(x1) = x1 / (1.0 + exp(-x1))
// swiglu(x1, x2) = swish(x1) * x2
packed_out[k] = (floatX)((x1 * x2) / (1.0f + expf(-x1)));
}
store128(out + idx, packed_out);
}

__global__ void swiglu_backward_inplace_kernel(floatX* dinp_out1, floatX* dinp2, const floatX* inp1, const floatX* inp2) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;

x128 packed_dinp1;
x128 packed_dinp2;
x128 packed_inp1 = load128cs(inp1 + idx);
x128 packed_inp2 = load128cs(inp2 + idx);
x128 packed_dinp_out1 = load128(dinp_out1 + idx);
for (int k = 0; k < packed_inp1.size; ++k) {
float x1 = (float)packed_inp1[k];
float x2 = (float)packed_inp2[k];
float sig_x1 = 1.0f / (1.0f + expf(-x1));
// swiglu(x1, x2) = swish(x1) * x2
// -> dout/dx1 = x2 * sigmoid(x1) + x2 * x1 * sigmoid(x1) * (1 - sigmoid(x1))
// ---> dout/dx1 = x2 * sigmoid(x1) * (1 + x1 * (1 - sigmoid(x1)))
// -> dout/dx2 = swish(x1) = x1 * sigmoid(x1)
float local_grad1 = x2 * sig_x1 * (1.0f + x1 * (1.0f - sig_x1));
float local_grad2 = x1 * sig_x1;
packed_dinp1[k] = (floatX)(local_grad1 * (float)packed_dinp_out1[k]);
packed_dinp2[k] = (floatX)(local_grad2 * (float)packed_dinp_out1[k]);
}
store128(dinp_out1 + idx, packed_dinp1);
store128(dinp2 + idx, packed_dinp2);
}
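The local gradient formulas in the comments above can be sanity-checked with a small standalone finite-difference test. This is an illustrative sketch, not part of this diff; all names are invented and it compiles as plain C:

#include <math.h>
#include <stdio.h>

// reference: swiglu(x1, x2) = swish(x1) * x2 = x1 * sigmoid(x1) * x2
static float swiglu_ref(float x1, float x2) { return x1 / (1.0f + expf(-x1)) * x2; }

int main(void) {
    float x1 = 0.7f, x2 = -1.3f, eps = 1e-3f;
    float sig = 1.0f / (1.0f + expf(-x1));
    float analytic1 = x2 * sig * (1.0f + x1 * (1.0f - sig)); // same expression as local_grad1
    float analytic2 = x1 * sig;                              // same expression as local_grad2
    float numeric1 = (swiglu_ref(x1 + eps, x2) - swiglu_ref(x1 - eps, x2)) / (2.0f * eps);
    float numeric2 = (swiglu_ref(x1, x2 + eps) - swiglu_ref(x1, x2 - eps)) / (2.0f * eps);
    printf("d/dx1: analytic %.6f vs numeric %.6f\n", analytic1, numeric1);
    printf("d/dx2: analytic %.6f vs numeric %.6f\n", analytic2, numeric2);
    return 0;
}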

// ----------------------------------------------------------------------------
// kernel launchers

@@ -64,3 +105,21 @@ void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cud
gelu_backward_inplace_kernel<<<grid_size, block_size, 0, stream>>>(d_in_out, inp);
cudaCheck(cudaGetLastError());
}

void swiglu_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 512;
assert(N % (block_size * x128::size) == 0);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
swiglu_forward_kernel<<<grid_size, block_size, 0, stream>>>(out, inp1, inp2);
cudaCheck(cudaGetLastError());
}

void swiglu_backward_inplace(floatX* dinp_out1, floatX* dinp2, const floatX* inp1, const floatX* inp2, const int N, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 128;
assert(N % (block_size * x128::size) == 0);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
swiglu_backward_inplace_kernel<<<grid_size, block_size, 0, stream>>>(dinp_out1, dinp2, inp1, inp2);
cudaCheck(cudaGetLastError());
}
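Both launchers assert that N is a multiple of block_size * x128::size, and the backward launcher reuses its first argument as both input (the upstream gradient) and output. A hypothetical call site follows, with invented buffer names and 4*C standing in for the MLP hidden size:

// forward: fch_out = swish(fch_pre1) * fch_pre2
swiglu_forward(fch_out, fch_pre1, fch_pre2, B*T*4*C, stream);
// backward: d_fch_pre1 holds the upstream gradient on entry and is overwritten in place
// with d(loss)/d(fch_pre1); d_fch_pre2 receives d(loss)/d(fch_pre2)
swiglu_backward_inplace(d_fch_pre1, d_fch_pre2, fch_pre1, fch_pre2, B*T*4*C, stream);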
llmc/matmul.cuh (37 additions, 13 deletions)
@@ -231,21 +231,40 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX*
void matmul_forward_cublaslt(floatX* out,
floatX* inp, floatX* weight, floatX* bias,
int B, int T, int C, int OC, cudaStream_t stream,
- floatX* pre_gelu=NULL, int gelu_fusion=1) {
- // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?)
- if (gelu_fusion < 1 && pre_gelu) {
- matmul_cublaslt(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false);
- gelu_forward(out, pre_gelu, B*T*OC, stream);
const char* act_func, floatX* pre_act=NULL, int act_func_fusion=1) {
// By default only fuse GELU/{act_func} for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?)
if (act_func_fusion < 1 && pre_act) {
assert(strcmp(act_func, "gelu") == 0);
matmul_cublaslt(pre_act, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false);
gelu_forward(out, pre_act, B*T*OC, stream);
} else {
- matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false);
if (pre_act != NULL) {
assert(strcmp(act_func, "gelu") == 0); // currently only GELU is supported for fusion
}
matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_act, false);
}
}

void matmul_forward_fc1(floatX* out,
floatX* inp, floatX* weight1, floatX* bias1, floatX* weight2, floatX* bias2,
int B, int T, int C, int OC, cudaStream_t stream,
const char* act_func, floatX* pre_act1=NULL, floatX* pre_act2=NULL, int act_func_fusion=1) {
if (weight2 == NULL) {
assert(bias2 == NULL);
matmul_forward_cublaslt(out, inp, weight1, bias1, B, T, C, OC, stream, act_func, pre_act1, act_func_fusion);
} else {
assert(strcmp(act_func, "swiglu") == 0);
matmul_cublaslt(pre_act1, weight1, inp, bias1, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false);
matmul_cublaslt(pre_act2, weight2, inp, bias2, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false);
swiglu_forward(out, pre_act1, pre_act2, B*T*OC, stream);
}
}
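A hedged sketch of how a caller might drive matmul_forward_fc1 for the two MLP variants. The local pointer names (l_ln2, l_fcw, l_gatew, l_fch_pre1, ...) and the 4*C hidden size are placeholders, not taken from this diff:

// GELU MLP: no second projection; behaves exactly like matmul_forward_cublaslt.
matmul_forward_fc1(l_fch_act, l_ln2, l_fcw, l_fcb, NULL, NULL,
                   B, T, C, 4*C, stream, "gelu", l_fch_pre1, NULL, model.act_func_fusion);
// SwiGLU MLP: both projections are computed, then combined elementwise by swiglu_forward;
// both pre-activations are kept for the backward pass.
matmul_forward_fc1(l_fch_act, l_ln2, l_fcw, l_fcb, l_gatew, l_gateb,
                   B, T, C, 4*C, stream, "swiglu", l_fch_pre1, l_fch_pre2, model.act_func_fusion);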

- void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
void matmul_backward(floatX* dinp1, floatX* dinp2, floatX* dweight, floatX* dbias,
floatX* dout, floatX* inp, floatX* weight,
float* dbias_buffer,
- int B, int T, int C, int OC, cudaStream_t stream,
- floatX* pre_gelu=NULL, int gelu_fusion=1) {
int B, int T, int C, int OC, cudaStream_t stream, int accumulate_input,
const char* act_func, floatX* pre_act1=NULL, floatX* pre_act2=NULL, int gelu_fusion=1) {
NVTX_RANGE_FN();

// backward to bias, if given, does a +=
@@ -275,13 +294,18 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
dbias = NULL; // prevent dbias calculation from also being fused in matmul_cublaslt below (if we enabled fusion)
}

int is_gelu = strcmp(act_func, "gelu") == 0;

// backward to input, uses = in the backward pass (set the gradient)
- matmul_cublaslt(dinp, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false,
- gelu_fusion >= 2 ? pre_gelu : NULL, true);
matmul_cublaslt(dinp1, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0,
accumulate_input /* accumulate */, is_gelu && gelu_fusion >= 2 ? pre_act1 : NULL, true);

// backward GELU (if it wasn't fused into the matmul above)
- if (gelu_fusion < 2 && pre_gelu) {
- gelu_backward_inplace(dinp, pre_gelu, B*T*C, stream);
if (is_gelu && gelu_fusion < 2 && pre_act1) {
gelu_backward_inplace(dinp1, pre_act1, B*T*C, stream);
} else if (!is_gelu && pre_act1) {
assert(strcmp(act_func, "swiglu") == 0);
swiglu_backward_inplace(dinp1, dinp2, pre_act1, pre_act2, B*T*C, stream);
}

// backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one
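For context, a hedged sketch of how the SwiGLU path of matmul_backward might be invoked when backpropagating through an MLP down-projection. Every buffer name is invented and the exact call site is not part of this diff:

// dl_btc: upstream gradient of shape (B,T,C); l_fch_act: the SwiGLU output saved in the
// forward pass; l_fch_pre1 / l_fch_pre2: the two saved fc1 pre-activations.
// d_fch_pre1 is first used as scratch for the gradient of the SwiGLU output, then
// overwritten in place with d(loss)/d(pre1); d_fch_pre2 receives d(loss)/d(pre2).
matmul_backward(d_fch_pre1, d_fch_pre2, dl_fcprojw, dl_fcprojb, dl_btc,
                l_fch_act, l_fcprojw, dbias_scratch,
                B, T, 4*C, C, stream, /*accumulate_input=*/0,
                "swiglu", l_fch_pre1, l_fch_pre2, model.act_func_fusion);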
llmc/zero.cuh (1 addition, 0 deletions)
@@ -529,6 +529,7 @@ void multi_gpu_async_reduce_gradient(
cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
for (int i = 0; i < N; ++i) {
if (pointers[i] == NULL) continue;
if(config->zero_stage == 0) {
ncclCheck(ncclAllReduce(
pointers[i], pointers[i],
test_llama3.cu (4 additions, 1 deletion)
@@ -59,6 +59,8 @@ typedef struct {
float* ln2b; // (L, C)
float* fcw; // (L, 4*C, C)
float* fcb; // (L, 4*C)
float* gatew; // (L, 4*C, C)
float* gateb; // (L, 4*C)
float* fcprojw; // (L, C, 4*C)
float* fcprojb; // (L, C)
float* lnfw; // (C)
@@ -78,6 +80,7 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size
float** ptrs[] = {
&params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
&params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
&params->gatew, &params->gateb,
&params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
};
float* params_memory_iterator = params_memory;
@@ -122,7 +125,7 @@ int main(int argc, char *argv[]) {
if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash
if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); }
else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); }
- else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); }
else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.act_func_fusion = atoi(argv[i+1]); }
}

// load additional information that we will use for debugging and error checking