diff --git a/llmc/attention.cuh b/llmc/attention.cuh
index f6294a213..8993929cf 100644
--- a/llmc/attention.cuh
+++ b/llmc/attention.cuh
@@ -189,11 +189,55 @@ __global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, con
     }
 }
 
+__global__ void rope_rotate_kernel(floatX* q, floatX* k, float* rope_freqs, int B, int NH, int T, int HS, int is_backward) {
+    // thanks to the nice mathematical properties of RoPE this is both our fwd & bwd pass kernel!
+    // the only difference is that we have to toggle the sign of the sin term in the rotation
+    // q, k are of shape (B, NH, T, HS)
+    // rope_freqs is of shape (T, HS/2)
+    int n = HS / x128::size;
+    int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (thread_idx >= B * NH * T * n) { return; }
+
+    int b = thread_idx / (NH * T * n);
+    int rest = thread_idx % (NH * T * n);
+    int nh = rest / (T * n);
+    rest = rest % (T * n);
+    int t = rest / n;
+    int i = rest % n;
+
+    float* rope_freqs_t = rope_freqs + t * (HS / 2) + i * (x128::size / 2);
+    f128 freqs_reg = load128(rope_freqs_t); // caching the frequencies
+
+    int idx = b * NH * T * HS + nh * T * HS + t * HS + i * x128::size;
+    x128 q_reg = load128cs(&q[idx]);
+    x128 k_reg = load128cs(&k[idx]);
+    x128 qout_reg, kout_reg;
+    for (int k = 0; k < x128::size / 2; k++) { // div by 2 because we're processing tuples of 2
+        // rotate q
+        floatX x1 = q_reg[2*k];
+        floatX x2 = q_reg[2*k + 1];
+        floatX q_out1 = (floatX)((float)x1 * cosf(freqs_reg[k]) + (is_backward ? 1 : -1) * (float)x2 * sinf(freqs_reg[k]));
+        floatX q_out2 = (floatX)((float)x2 * cosf(freqs_reg[k]) + (is_backward ? -1 : 1) * (float)x1 * sinf(freqs_reg[k]));
+        qout_reg[2*k] = q_out1;
+        qout_reg[2*k + 1] = q_out2;
+        // rotate k
+        x1 = k_reg[2*k];
+        x2 = k_reg[2*k + 1];
+        floatX k_out1 = (floatX)((float)x1 * cosf(freqs_reg[k]) + (is_backward ? 1 : -1) * (float)x2 * sinf(freqs_reg[k]));
+        floatX k_out2 = (floatX)((float)x2 * cosf(freqs_reg[k]) + (is_backward ? -1 : 1) * (float)x1 * sinf(freqs_reg[k]));
+        kout_reg[2*k] = k_out1;
+        kout_reg[2*k + 1] = k_out2;
+    }
+
+    store128cs(&q[idx], qout_reg);
+    store128cs(&k[idx], kout_reg);
+}
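
For reference, the math applied above is the standard RoPE 2D rotation of each (x1, x2) pair, and the backward pass is the inverse (transposed) rotation, which is why only the sign of the sin term flips. A minimal host-side sketch of the per-pair update (a hypothetical helper, not part of this diff) that mirrors the kernel:

    // reference rotation for one (x1, x2) pair; is_backward applies the inverse rotation
    #include <math.h>
    static void rope_rotate_pair_ref(float* x1, float* x2, float freq, int is_backward) {
        float c = cosf(freq), s = sinf(freq);
        float sign = is_backward ? -1.0f : 1.0f;    // flipping the sign of sin inverts the rotation
        float out1 = (*x1) * c - sign * (*x2) * s;  // corresponds to q_out1 / k_out1 above
        float out2 = (*x2) * c + sign * (*x1) * s;  // corresponds to q_out2 / k_out2 above
        *x1 = out1;
        *x2 = out2;
    }

Since a rotation by -freq is both the transpose and the inverse of a rotation by freq, running the same kernel with is_backward=1 on dq/dk is exactly the gradient of the forward rotation applied to q/k.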
+
 // ----------------------------------------------------------------------------
 // kernel launchers
 
 void attention_forward(floatX* out, floatX* qkvr, floatX* att,
-                       floatX* inp,
+                       floatX* inp, int use_rope, float* rope_freqs,
                        int B, int T, int C, int NH, cudaStream_t stream) {
     NVTX_RANGE_FN();
     // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.
@@ -214,6 +258,13 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
     int num_blocks = CEIL_DIV(total_threads, block_size);
     permute_kernel<<<num_blocks, block_size, 0, stream>>>(q, k, v, inp, B, T, NH, HS);
 
+    if (use_rope) {
+        assert(HS % x128::size == 0);
+        total_threads = B * NH * T * (HS / x128::size);
+        num_blocks = CEIL_DIV(total_threads, block_size);
+        rope_rotate_kernel<<<num_blocks, block_size, 0, stream>>>(q, k, rope_freqs, B, NH, T, HS, 0);
+    }
+
     floatX* preatt = inp; // reuse inp as scratch buffer
     matmul_cublaslt(preatt, k, q, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T);
@@ -239,6 +290,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
 void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch,
                         const floatX* dout,
                         const floatX* qkvr, const floatX* att,
+                        int use_rope, float* rope_freqs,
                         int B, int T, int C, int NH, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
@@ -269,6 +321,14 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scrat
     matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS);
     // backward into k
     matmul_cublaslt(dk, q, dpreatt, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS);
+
+    if (use_rope) {
+        assert(HS % x128::size == 0);
+        int total_threads = B * NH * T * (HS / x128::size);
+        num_blocks = CEIL_DIV(total_threads, block_size);
+        rope_rotate_kernel<<<num_blocks, block_size, 0, stream>>>(dq, dk, rope_freqs, B, NH, T, HS, 1);
+    }
+
     // backward into inp
     num_blocks = CEIL_DIV(B * NH * T * HS, block_size);
     permute_kernel_backward<<<num_blocks, block_size, 0, stream>>>(dinp, dq, dk, dv, B, T, NH, HS);
diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh
index 3aa63e175..285555752 100644
--- a/llmc/encoder.cuh
+++ b/llmc/encoder.cuh
@@ -17,7 +17,7 @@ In the backward pass, the gradients flow to both, handled by different kernels
 // CUDA kernels
 
 __global__ void encoder_forward_kernel3(floatX* out,
-                               const int* inp, const floatX* wte, const floatX* wpe,
+                               const int* inp, const floatX* wte, const floatX* wpe, int use_rope,
                                int B, int T, int C) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
     int N = B * T * C;
@@ -36,9 +36,16 @@ __global__ void encoder_forward_kernel3(floatX* out,
 
         x128 packed_out;
         x128 wte128 = load128cs(wte_ix);
-        x128 wpe128 = load128cs(wpe_tc);
+        x128 wpe128;
+        if (!use_rope) {
+            wpe128 = load128cs(wpe_tc);
+        }
         for (int k = 0; k < x128::size; k++) {
-            packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
+            if (!use_rope) {
+                packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
+            } else {
+                packed_out[k] = wte128[k];
+            }
         }
         store128(out_btc, packed_out);
     }
@@ -151,17 +158,33 @@ __global__ void wpe_backward_kernel(floatX* dwpe,
     store128(dwpe_tc, packed_dwpe);
 }
 
+__global__ void init_rope_freqs_kernel(float* rope_freqs, float rope_base_freq) {
+    int m = blockIdx.x;
+    int d_half = blockDim.x;
+    int i = threadIdx.x + 1;
+    int out_idx = m * d_half + i - 1;
+
+    float theta_i = __powf(rope_base_freq, -2.0f * (float)(i - 1) / (2.f * (float)d_half));
+    rope_freqs[out_idx] = (float)m * theta_i;
+}
+
 // ----------------------------------------------------------------------------
 // kernel launchers
 
+void init_rope_freqs(float* rope_freqs, int max_seq_len, int HS, float rope_base_freq, cudaStream_t stream) {
+    NVTX_RANGE_FN();
+    init_rope_freqs_kernel<<<max_seq_len, HS, 0, stream>>>(rope_freqs, rope_base_freq);
+    cudaCheck(cudaGetLastError());
+}
+
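
The table built by init_rope_freqs_kernel follows the usual RoPE schedule: for a head of size HS, position m and 1-based pair index i, freqs[m][i] = m * base^(-2(i-1)/HS), stored as a (max_seq_len, HS/2) array. A CPU sketch of the same table (hypothetical helper, for illustration only; head_size here is the full per-head dimension, whereas the launcher above receives HS/2 as its HS argument):

    // CPU reference for the (max_seq_len, head_size/2) frequency table
    #include <math.h>
    static void init_rope_freqs_ref(float* rope_freqs, int max_seq_len, int head_size, float base) {
        int d_half = head_size / 2;
        for (int m = 0; m < max_seq_len; m++) {      // position index (blockIdx.x in the kernel)
            for (int i = 1; i <= d_half; i++) {      // 1-based frequency index (threadIdx.x + 1)
                float theta_i = powf(base, -2.0f * (float)(i - 1) / (float)head_size);
                rope_freqs[m * d_half + (i - 1)] = (float)m * theta_i;
            }
        }
    }

rope_rotate_kernel reads one slice of this table per (t, i) group, so the whole table is only max_seq_len * HS/2 floats, which matches the cudaMalloc in train_llama3.cu below.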
 void encoder_forward(floatX* out,
-                     const int* inp, const floatX* wte, const floatX* wpe,
+                     const int* inp, const floatX* wte, const floatX* wpe, int use_rope,
                      int B, int T, int C, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 256;
     const int N = B * T * C;
     const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
-    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, B, T, C);
+    encoder_forward_kernel3<<<grid_size, block_size, 0, stream>>>(out, inp, wte, wpe, use_rope, B, T, C);
     cudaCheck(cudaGetLastError());
 }
 
@@ -169,15 +192,17 @@ void encoder_forward(floatX* out,
 void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch
                       int* workload_indices, int4* bucket_info,    // cpu scratch buffers
                       const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
-                      int B, int T, int C, unsigned int seed, cudaStream_t stream) {
+                      int use_rope, int B, int T, int C, unsigned int seed, cudaStream_t stream) {
     NVTX_RANGE_FN();
 
-    // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
-    const int block_size = 256;
-    const int N = T * C / x128::size;
-    const int grid_size = CEIL_DIV(N, block_size);
-    wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
-    cudaCheck(cudaGetLastError());
+    if (!use_rope) {
+        // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
+        const int block_size = 256;
+        const int N = T * C / x128::size;
+        const int grid_size = CEIL_DIV(N, block_size);
+        wpe_backward_kernel<<<grid_size, block_size, 0, stream>>>(dwpe, dout, inp, B, T, C, seed);
+        cudaCheck(cudaGetLastError());
+    }
 
     // check the GPU scratch buffer is large enough to hold the bucket info and workload indices
     // todo - this is trivially true given hardcoded scratch buffer size here, is this useful?
diff --git a/llmc/zero.cuh b/llmc/zero.cuh
index e6c5b6e7c..fe2fc89cd 100644
--- a/llmc/zero.cuh
+++ b/llmc/zero.cuh
@@ -529,6 +529,7 @@ void multi_gpu_async_reduce_gradient(
     cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync));
     ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel.
     for (int i = 0; i < N; ++i) {
+        if (pointers[i] == NULL) continue;
         if(config->zero_stage == 0) {
             ncclCheck(ncclAllReduce(
                 pointers[i], pointers[i],
diff --git a/train_llama3.cu b/train_llama3.cu
index 70d8d0c5a..d3ab818f3 100644
--- a/train_llama3.cu
+++ b/train_llama3.cu
@@ -319,6 +319,9 @@ typedef struct {
     // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
     int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
     int4* bucket_info;     // encoder_backward, B*T*num_c_groups (int4) - size for worst case
+    int use_rope; // use rope position encoding
+    float rope_base_freq; // base frequency for rope position encoding
+    float* rope_freqs; // rope position encoding frequencies
 } GPT2;
 
 void gpt2_init_common(GPT2 *model) {
@@ -348,6 +351,10 @@ void gpt2_init_common(GPT2 *model) {
     model->init_state = true;
     model->recompute = 1; // good default: recompute gelu but not layernorm
     model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
+    // architecture specific settings
+    model->use_rope = 0; // use rope position encoding
+    model->rope_base_freq = 10000.0f; // base frequency for rope position encoding
+    model->rope_freqs = NULL; // rope position encoding frequencies
 }
 
 void gpt2_allocate_weights(GPT2 *model) {
@@ -362,6 +369,15 @@ void gpt2_allocate_weights(GPT2 *model) {
     // create memory for model parameters on the device
     assert(model->params_memory == nullptr);
     model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);
+
+    // allocate memory for rope frequencies
+    if (model->use_rope) {
+        int HS = model->config.channels / model->config.num_heads;
+        assert(HS % 2 == 0); // HS must be even for RoPE
+        cudaCheck(cudaMalloc((float**)&model->rope_freqs, model->config.max_seq_len * (HS / 2) * sizeof(float)));
+        // TODO(gordicaleksa): would floatX mess up the rope frequencies due to a lower precision?
+        init_rope_freqs(model->rope_freqs, model->config.max_seq_len, HS / 2, model->rope_base_freq, main_stream);
+    }
 }
 
 void gpt2_allocate_state(GPT2 *model, int B, int T) {
@@ -677,7 +693,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     // forward pass
     ParameterTensors params = model->params; // for brevity
     ActivationTensors acts = model->acts;
-    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]
+    encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, model->use_rope, B, T, C, main_stream); // encoding goes into residual[0]
 
     // first layernorm isn't fused
     layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
@@ -727,7 +743,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
         // these are only needed as scratchpads for the forward pass, but
         // need not be stored for backward
         matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
-        attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream);
+        attention_forward(l_atty, l_qkvr, l_att, scratch, model->use_rope, model->rope_freqs, B, T, C, NH, main_stream);
 #endif
 
         matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
@@ -915,7 +931,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
         floatX* buffer_a = l_atty;
         floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
-        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
+        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, model->use_rope, model->rope_freqs, B, T, C, NH, main_stream);
 #endif
         if(model->recompute >= 2) {
             layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
@@ -947,7 +963,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         }
     }
     encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
-                     dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);
+                     dresidual, model->inputs, inputs, model->use_rope, B, T, C, random_u32(&model->rng_state), main_stream);
 
     // Aggregate all gradients that are not part of the transformer blocks
     if(last_step) {
@@ -959,7 +975,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
 #endif
         cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));
         // reduce the gradients for non-transformer block parameters
-        floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
+        floatX* const pointers[] = {grads.wte, model->use_rope ? NULL : grads.wpe, grads.lnfw, grads.lnfb};
         const size_t nelem[] = {Vp * C, T * C, C, C};
         multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
     }
@@ -1004,6 +1020,10 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
     // grads_memory only contains the averaged gradients at the local shards,
     // so we only calculate the grad norm at the grads_memory belonging to the local shards
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
+        if (model->use_rope && i == 1) {
+            // skip the wpe tensor if we are using RoPE -> minor optimization, results would be correct without this as well
+            continue;
+        }
         ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i);
         ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1);
         ptrdiff_t offset = tensor.offset + shard.offset;
@@ -1060,6 +1080,11 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
     // AdamW update
     // handle adamw for all the transformer blocks
     for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
+        if (model->use_rope && i == 1) {
+            // skip the wpe tensor if we are using RoPE
+            continue;
+        }
+
         // generate a unique seed for each tensor
         unsigned int seed = random_u32(&model->rng_state);
@@ -1404,6 +1429,8 @@ void error_usage() {
     // memory management
     fprintf(stderr, "  -z <int>    zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n");
     fprintf(stderr, "  -r <int>    recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\n");
+    // architectural settings
+    fprintf(stderr, "  -er <int>   enable RoPE positional embeddings? (default = 0)\n");
     // multi-node settings
     fprintf(stderr, "  -pn <int>    num_processes (default = 1)\n");
     fprintf(stderr, "  -pr <int>    process_rank (default = 0)\n");
@@ -1449,6 +1476,9 @@ int main(int argc, char *argv[]) {
     int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
     int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
     int hellaswag_eval = 0;
+    // architectural settings
+    int use_rope = 0; // use RoPE positional embeddings
+    float rope_base_freq = 10000.0f; // base frequency for RoPE
     // multi-node settings
     int num_processes = 1;  // this should be set by the slurm environment
     int process_rank = 0;  // this should be set by the slurm environment
@@ -1463,7 +1493,7 @@ int main(int argc, char *argv[]) {
         // read in the args
         if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
         else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
-        else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }
+        else if (argv[i][1] == 'e' && argv[i][2] == '\0') { load_filename = argv[i+1]; }
        else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; }
         else if (argv[i][1] == 'n' && argv[i][2] == '\0') { checkpoint_every = atoi(argv[i+1]); }
         else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); }
@@ -1498,6 +1528,7 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); }
         else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); }
         else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'e' && argv[i][2] == 'r') { use_rope = atoi(argv[i+1]); }
         else { error_usage(); }
     }
@@ -1571,6 +1602,12 @@ int main(int argc, char *argv[]) {
     // build the GPT-2 model
     GPT2 model;
     gpt2_init_common(&model);
+    // architectural modifications
+    #ifdef ENABLE_CUDNN
+    use_rope = 0; // RoPE is not supported with cudnn atm
+    #endif
+    model.use_rope = use_rope;
+    model.rope_base_freq = rope_base_freq;
     if (resuming == 1) {
         // if `-y 1` was set, then we are resuming from the latest checkpoint
         // if we are using master weights, we'll init them later inside load_state()
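
Putting the pieces together: with -er 1 the encoder skips wpe entirely, attention_forward rotates q/k in place right after the QKV permute, attention_backward applies the inverse rotation to dq/dk, and the wpe gradient is neither reduced nor updated. RoPE is not supported with the cuDNN attention path yet, so use_rope is forced back to 0 when ENABLE_CUDNN is defined. As a quick standalone sanity check of the sign convention used by rope_rotate_kernel (illustrative test code, not part of this diff), the forward rotation followed by the backward one must reproduce the original pair:

    // round-trip check for the fwd/bwd sign toggle used in rope_rotate_kernel
    #include <assert.h>
    #include <math.h>
    int main(void) {
        float x1 = 0.3f, x2 = -1.7f;             // arbitrary input pair
        float f  = 0.123f;                       // arbitrary frequency (angle)
        float r1 = x1 * cosf(f) - x2 * sinf(f);  // forward, is_backward = 0
        float r2 = x2 * cosf(f) + x1 * sinf(f);
        float b1 = r1 * cosf(f) + r2 * sinf(f);  // backward, is_backward = 1
        float b2 = r2 * cosf(f) - r1 * sinf(f);
        assert(fabsf(b1 - x1) < 1e-6f && fabsf(b2 - x2) < 1e-6f);
        return 0;
    }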