From 09b47a747d16cd8cb59f75758daa8c5643c93a04 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 13 Sep 2024 19:29:29 +0000 Subject: [PATCH 01/63] llama3 starting point is at gpt-2 exact copy paste for both train/test files --- test_llama3.cu | 395 ++++++++++ train_llama3.cu | 1904 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2299 insertions(+) create mode 100644 test_llama3.cu create mode 100644 train_llama3.cu diff --git a/test_llama3.cu b/test_llama3.cu new file mode 100644 index 000000000..e608ce229 --- /dev/null +++ b/test_llama3.cu @@ -0,0 +1,395 @@ +#define TESTING +#include "train_gpt2.cu" + +// poor man's tensor checker +int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) { + // a is the calculated tensor, b is the reference tensor + int print_upto = 10; + int ok = 1; + float max_diff = 0.0f; + float max_rel_error = 0.0f; + float max_to_threshold = 0.f; + float max_a = 0.0f; + float max_b = 0.0f; + float epsilon = 0.079; // BF16 epsilon value + printf("---\n"); + printf("checking tensor: %s\n", label); + for (int i = 0; i < n; i++) { + float t_eff = threshold + fabs(b[i]) * epsilon; + float diff = fabsf(a[i] - b[i]); + max_to_threshold = max(max_to_threshold, diff / t_eff); + if (diff > max_diff) { + max_diff = diff; + float denom = fabsf(b[i]); + max_rel_error = (denom == 0.0f) ? 0.0f : diff / denom; + max_a = a[i]; + max_b = b[i]; + } + if (diff > t_eff) { + ok = 0; + } + // print the first few elements so we can visually assess the "proof" of the comparison + if (i < print_upto) { + printf(diff <= t_eff ? "OK " : "NOT OK "); + printf("%f %f\n", a[i], b[i]); + } + } + // print the final result + if (ok) { + printf("TENSOR OK, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\n", + max_diff, max_rel_error, max_a, max_b, max_to_threshold*100); + } else { + printf("TENSOR NOT OK, max diff: %.3e, with rel error: %.3e (calculated=%10f, ref=%10f), %.2f%% of maximum error\n", + max_diff, max_rel_error, max_a, max_b, max_to_threshold*100); + } + return ok; +} + +// the same tensors as in the train file, but in float, which are used as reference +typedef struct { + float* wte; // (Vp, C) + float* wpe; // (maxT, C) + float* ln1w; // (L, C) + float* ln1b; // (L, C) + float* qkvw; // (L, 3*C, C) + float* qkvb; // (L, 3*C) + float* attprojw; // (L, C, C) + float* attprojb; // (L, C) + float* ln2w; // (L, C) + float* ln2b; // (L, C) + float* fcw; // (L, 4*C, C) + float* fcb; // (L, 4*C) + float* fcprojw; // (L, C, 4*C) + float* fcprojb; // (L, C) + float* lnfw; // (C) + float* lnfb; // (C) +} FloatParameterTensors; +static_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); + +// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU +float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) { + // calculate the total number of parameters + size_t num_parameters = 0; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += param_sizes[i]; + } + // everything is float so number of bytes to allocate is a simple multiplication + float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); + float** ptrs[] = { + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb + }; + float* params_memory_iterator = params_memory; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + *(ptrs[i]) = params_memory_iterator; + params_memory_iterator += param_sizes[i]; + } + return params_memory; +} + +int main(int argc, char *argv[]) { + char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi" + int num_processes = -1; // doesn't matter when using MPI + int process_rank = -1; // doesn't matter when using MPI + int gpus_per_node = -1; // doesn't matter when using MPI + char server_ip[256] = ""; // doesn't matter when using MPI + char fs_path[256] = ""; // doesn't matter when using MPI + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); + common_start(false, true); + + // set the right paths + #if defined(ENABLE_BF16) + const char* load_filename = "gpt2_124M_bf16.bin"; + #else + const char* load_filename = "gpt2_124M.bin"; + #endif + + // build the GPT-2 model from a checkpoint + GPT2 model; + gpt2_init_common(&model); + + gpt2_build_from_checkpoint(&model, load_filename); + size_t V = model.config.vocab_size; + size_t Vp = model.config.padded_vocab_size; + size_t maxT = model.config.max_seq_len; + size_t L = model.config.num_layers; + size_t C = model.config.channels; + + for (int i = 1; i < argc; i+=2) { + if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag + if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { exit(EXIT_FAILURE); } // must be -x[y] (one dash, one or two letters) + if (argv[i][0] != '-') { exit(EXIT_FAILURE); } // must start with dash + if (argv[i][1] == 'w') { model.use_master_weights = atoi(argv[i+1]); } + else if (argv[i][1] == 'r') { model.recompute = atoi(argv[i+1]); } + else if (argv[i][1] == 'g' && argv[i][2] == 'e') { model.gelu_fusion = atoi(argv[i+1]); } + } + + // load additional information that we will use for debugging and error checking + FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); + int state_header[256]; + freadCheck(state_header, sizeof(int), 256, state_file); + if (state_header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); } + if (state_header[1] != 2) { + fprintf(stderr, "Bad version in state file\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(EXIT_FAILURE); + } + int B = state_header[2]; // batch size, e.g. 4 + int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT) + assert(0 <= T && T <= maxT); + printf("[State]\n"); + printf("batch_size: %d\n", B); + printf("seq_len: %d\n", T); + + set_zero_configs(&multi_gpu_config, 0, model.num_parameters); + + // read reference information from the file saved from Python/PyTorch side + // 1) input x and y + int* x = (int*)mallocCheck(B * T * sizeof(int)); + int* y = (int*)mallocCheck(B * T * sizeof(int)); + freadCheck(x, sizeof(int), B*T, state_file); + freadCheck(y, sizeof(int), B*T, state_file); + // 2) results of forward pass (logits and loss) + float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float)); + float* expected_loss = (float*) mallocCheck(1 * sizeof(float)); + freadCheck(expected_logits, sizeof(float), B*T*V, state_file); + freadCheck(expected_loss, sizeof(float), 1, state_file); + // 3) results of backward pass (parameter gradients) + FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32 + float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements); + freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file); + fcloseCheck(state_file); + + // this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads + void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes); + float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float)); + + // overall OK signal for the test + int allok = 1; + + gpt2_allocate_state(&model, B, T); + + // First, do target-free forward pass to validate logits + gpt2_forward(&model, x, B, T); + // at this point, target should be equal to expected_logits, let's compare + // copy logits to CPU so we can compare them + floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX)); + float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float)); + cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < B * T * Vp; i++) { + logits_cpu[i] = (float)logits_cpu_raw[i]; + } + + float logit_accuracy_threshold = 1e-3f; + float loss_diff_threshold = 1e-5f; + // FP16 and lower require very high tolerances unfortunately. TODO look into more + #if defined(ENABLE_BF16) || defined(ENABLE_F16) + logit_accuracy_threshold = 25.0f; // 15.0f was too low even without cuDNN?! :( + loss_diff_threshold = 0.05f; + #endif + + // compare the output logits from the forward pass + // also careful that we don't access and compare the padded columns of logits + int logits_ok = 1; + float max_diff = 0.0f; + for (int bt = 0; bt < B*T; bt++) { + for (int v = 0; v < V; v++) { + int i = bt * Vp + v; // linearized index + if (i < 10) { + printf("%f, %f\n", expected_logits[i], logits_cpu[i]); + } + float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]); + max_diff = fmaxf(max_diff, diff); + if (diff >= logit_accuracy_threshold) { + printf("MISMATCH AT INDEX %d,%d: ", bt, v); + printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]); + logits_ok = 0; + bt = B*T; // to break out of both loops + break; + } + } + } + allok = allok && logits_ok; + if(!logits_ok) { printf("NOT "); } + printf("OK (LOGITS)\n"); + printf("logit max diff: %f\n", max_diff); + + // let's do 10 training iterations, following the pytorch code + float losses[10]; + for (int step = 0; step < 10; step++) { + struct timespec start, end; + clock_gettime(CLOCK_MONOTONIC, &start); + gpt2_forward(&model, x, B, T); + gpt2_backward_and_reduce(&model, x, y, 1, 0); + clock_gettime(CLOCK_MONOTONIC, &end); + double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; + + if (step == 0) { + // error checking at step 0 for reference activations + + // move the (mixed precision) grads from GPU to CPU + cudaCheck(cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost)); + + // convert all gradients to float on the CPU + char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char* + float* dst_iterator = (float*)grads_memory_cpu_float; // float* + float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python + float* tensors1[NUM_PARAMETER_TENSORS]; + float* tensors2[NUM_PARAMETER_TENSORS]; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + if (model.param_sizeof[i] == sizeof(float)) { + // float tensor => copy over directly + memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float)); + } else { + // low-precision tensor => convert to float + assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm + for (size_t j = 0; j < model.param_elements[i]; j++) { + dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float + } + } + // for convenience record the position of comparison for reality vs. expectation + tensors1[i] = dst_iterator; // reality + tensors2[i] = exp_iterator; // expectation + // advance the iterators + src_iterator += model.param_elements[i] * model.param_sizeof[i]; + dst_iterator += model.param_elements[i]; + exp_iterator += model.param_elements[i]; + } + + // compare the gradients on the parameters all at once, in fp32 + // I set the tolerances manually by inspecting the gradient differences for + // a few elements of each tensor. bf16 looks ok but not amazing here. + // It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure. + // Also, if code changes and some of these get tripped, it could be ok if it's not by too much, + // because our use of stochastic rounding is adding some non-determinism "pepper noise". + // In that case it's ok to extend the tolerance by a bit, after a manual review. + // Also, different GPUs may use different matrix multiplication algorithms, so the + // actual errors can be hardware specific. + + float grad_thresholds[NUM_PARAMETER_TENSORS] = {5e-1f, 4e-3f, 1e-1f, 3.5e-2f, 2e-2f, 3e-2f, 5e-2f, 5e-2f, 5e-2f, 1.5e-2f, 5e-4f, 8e-3f, 1.5e-3f, 2.5e-3f, 1e-1f, 2e-2f}; + #if defined(ENABLE_FP32) + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + grad_thresholds[i] = 1e-6f; // we can be much more precise in FP32 + } + #endif + + allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", grad_thresholds[0]); + allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", grad_thresholds[1]); + allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", grad_thresholds[2]); + allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", grad_thresholds[3]); + allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", grad_thresholds[4]); + allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", grad_thresholds[5]); + allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", grad_thresholds[6]); + allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", grad_thresholds[7]); + allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", grad_thresholds[8]); + allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", grad_thresholds[9]); + allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", grad_thresholds[10]); + allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", grad_thresholds[11]); + allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", grad_thresholds[12]); + allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", grad_thresholds[13]); + allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", grad_thresholds[14]); + allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]); + } + + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); + + // print the timing information at the end + printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); + // the expected losses from PyTorch were copied over after the print formatting rounded + // them to 6 decimal places, so we do the same here + float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000; + losses[step] = rounded_loss; + } + + // expected losses are as follows, from Python + float expected_losses[10] = { + 5.270009f, + 4.060681f, + 3.320085f, + 2.717550f, + 2.181066f, + 1.653923f, + 1.168050f, + 0.736873f, + 0.401021f, + 0.187493f + }; + + // compare + for (int i = 0; i < 10; i++) { + if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) { + printf("LOSS MISMATCH AT STEP %d: %f %f\n", i+1, losses[i], expected_losses[i]); + allok = 0; + } else { + printf("loss ok at step %d: %f %f\n", i+1, losses[i], expected_losses[i]); + } + } + + // Finally, let's check determinism + gpt2_write_to_checkpoint(&model, "test_gpt2cu_model.ckpt"); + + DataLoader loader; + dataloader_init(&loader, "dev/data/tinyshakespeare/tiny_shakespeare_val.bin", B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1); + save_state("test_gpt2cu_state.ckpt", 10, &model, &loader); + int tokens[10]; + for (int step = 0; step < 10; step++) { + dataloader_next_batch(&loader); + gpt2_forward(&model, loader.inputs, B, T); + gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); + losses[step] = model.mean_loss; + tokens[step] = loader.inputs[0]; + } + + // reload + gpt2_free(&model); + gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt"); + int ld_step; + gpt2_allocate_state(&model, B, T); + load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt"); + for (int step = 0; step < 10; step++) { + dataloader_next_batch(&loader); + gpt2_forward(&model, loader.inputs, B, T); + gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); + + if(loader.inputs[0] != tokens[step]) { + printf("Nondeterminism! Token mismatch at step %d: %d vs %d\n", step, tokens[step], loader.inputs[0]); + allok = false; + break; + } + + if(losses[step] != model.mean_loss) { + printf("Nondeterminism! Loss mismatch at step %d: %.15f vs %.15f\n", step, losses[step], model.mean_loss); + allok = false; + break; + } else { + printf("loss ok at step %d: %f %f\n", step, losses[step], model.mean_loss); + } + } + + // final approval + printf("overall okay: %d\n", allok); + + // delete intermediate test files + remove("test_gpt2cu_model.ckpt"); + remove("test_gpt2cu_state.ckpt"); + + // free everything + dataloader_free(&loader); + gpt2_free(&model); + common_free(model); + free(x); + free(y); + free(logits_cpu_raw); + free(logits_cpu); + free(expected_logits); + free(expected_loss); + free(expected_grads_memory); + free(grads_memory_cpu); + free(grads_memory_cpu_float); + return allok ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/train_llama3.cu b/train_llama3.cu new file mode 100644 index 000000000..70d8d0c5a --- /dev/null +++ b/train_llama3.cu @@ -0,0 +1,1904 @@ +/* +GPT-2 Transformer Neural Net training loop. See README.md for usage. +*/ +#include +#include +#include +#include +#include +#include +#include +#include +// ----------- CPU utilities ----------- +// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck +// defines: create_dir_if_not_exists, find_max_step, ends_with_bin +#include "llmc/utils.h" +// defines: tokenizer_init, tokenizer_decode, tokenizer_free +#include "llmc/tokenizer.h" +// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free +// defines: evalloader_init, evalloader_reset, evalloader_next_batch, evalloader_free +#include "llmc/dataloader.h" +// defines: manual_seed, normal_ (same as torch.manual_seed and torch.normal) +#include "llmc/rand.h" +// defines: lr_scheduler_init, get_learning_rate +#include "llmc/schedulers.h" +// defines: sample_softmax, random_f32 +#include "llmc/sampler.h" +// defines: logger_init, logger_log_eval, logger_log_val, logger_log_train +#include "llmc/logger.h" +// defines: get_flops_promised +#include "llmc/mfu.h" +// defines: OutlierDetector, init_detector, update_detector +#include "llmc/outlier_detector.h" +// ----------- GPU utilities ----------- +// defines: +// WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE +// NVTX_RANGE_FN +#include "llmc/cuda_common.h" +// defines: +// Packed128, f128, x128 +// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged +#include "llmc/cuda_utils.cuh" +// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace +// defines: cublas_compute, cublaslt_handle, cublas_handle +#include "llmc/cublas_common.h" +// ----------- Layer implementations in CUDA ----------- +// defines: encoder_forward, encoder_backward +#include "llmc/encoder.cuh" +// defines: layernorm_forward, residual_forward, fused_residual_forward5, layernorm_backward +#include "llmc/layernorm.cuh" +// defines: matmul_cublaslt, matmul_forward, matmul_backward, gelu_forward, gelu_backward_inplace +#include "llmc/matmul.cuh" +#ifdef ENABLE_CUDNN +// defines: create_cudnn, destroy_cudnn, attention_forward_cudnn, attention_backward_cudnn +#include "llmc/cudnn_att.h" +#else +// defines: attention_forward, attention_backward +#include "llmc/attention.cuh" +#endif +// defines: fused_classifier +#include "llmc/fused_classifier.cuh" +// defines: adamw_kernel3 +#include "llmc/adamw.cuh" +// defines: global_norm_squared +#include "llmc/global_norm.cuh" +// ----------- Multi-GPU support ----------- +// defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo +// defines: printf0, multi_gpu_config +// defines: multi_gpu_config_init, multi_gpu_config_free +// defines: set_zero_configs, multi_gpu_cpu_float_sum, multi_gpu_barrier +// defines: multi_gpu_get_shard_offset, multi_gpu_async_reduce_gradient +#include "llmc/zero.cuh" + +// ---------------------------------------------------------------------------- +// global vars for I/O +char filename_buffer[512]; + +// ---------------------------------------------------------------------------- +// global vars containing information about the GPU this process is running on +cudaDeviceProp deviceProp; // fills in common_start() +cudaStream_t main_stream; +// buffer size to use for device <-> disk io +constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; + +// ---------------------------------------------------------------------------- +// GPT-2 model definition + +typedef struct { + int max_seq_len; // max sequence length, e.g. 1024 + int vocab_size; // vocab size, e.g. 50257 + int padded_vocab_size; // padded to e.g. %128==0, 50304 + int num_layers; // number of layers, e.g. 12 + int num_heads; // number of heads in attention, e.g. 12 + int channels; // number of channels, e.g. 768 +} GPT2Config; + +// the parameters of the model +constexpr const int NUM_PARAMETER_TENSORS = 16; +typedef struct { + floatX* wte; // (V, C) + floatX* wpe; // (maxT, C) + floatX* ln1w; // (L, C) + floatX* ln1b; // (L, C) + floatX* qkvw; // (L, 3*C, C) + floatX* qkvb; // (L, 3*C) + floatX* attprojw; // (L, C, C) + floatX* attprojb; // (L, C) + floatX* ln2w; // (L, C) + floatX* ln2b; // (L, C) + floatX* fcw; // (L, 4*C, C) + floatX* fcb; // (L, 4*C) + floatX* fcprojw; // (L, C, 4*C) + floatX* fcprojb; // (L, C) + floatX* lnfw; // (C) + floatX* lnfb; // (C) +} ParameterTensors; +static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); + +void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { + size_t Vp = config.padded_vocab_size; + size_t C = config.channels; + size_t maxT = config.max_seq_len; + size_t L = config.num_layers; + param_sizes[0] = Vp * C; // wte + param_sizes[1] = maxT * C; // wpe + param_sizes[2] = L * C; // ln1w + param_sizes[3] = L * C; // ln1b + param_sizes[4] = L * (3 * C) * C; // qkvw + param_sizes[5] = L * (3 * C); // qkvb + param_sizes[6] = L * C * C; // attprojw + param_sizes[7] = L * C; // attprojb + param_sizes[8] = L * C; // ln2w + param_sizes[9] = L * C; // ln2b + param_sizes[10] = L * (4 * C) * C; // fcw + param_sizes[11] = L * (4 * C); // fcb + param_sizes[12] = L * C * (4 * C); // fcprojw + param_sizes[13] = L * C; // fcprojb + param_sizes[14] = C; // lnfw + param_sizes[15] = C; // lnfb + + // populate the parameter sizes in bytes (all the same for now, keeping for future use) + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + param_sizeof[i] = sizeof(floatX); + } +} + +// allocate memory for the parameters and point the individual tensors to the right places +void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t *param_sizeof) { + // calculate the total number of parameters and bytes across all tensors + size_t num_parameters_bytes = 0; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters_bytes += param_elements[i] * param_sizeof[i]; + } + // malloc all parameters all at once on the device + void* params_memory; + cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters_bytes)); + // assign all the tensors their place in the array + floatX** ptrs[] = { + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb + }; + char* params_memory_iterator = (char*)params_memory; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + *(ptrs[i]) = (floatX*)params_memory_iterator; + params_memory_iterator += param_elements[i] * param_sizeof[i]; + } + return params_memory; +} + +constexpr int NUM_ACTIVATION_TENSORS = 21; +typedef struct { + floatX* encoded; // (B, T, C) + floatX* ln1; // (L, B, T, C) + float* ln1_mean; // (L, B, T) + float* ln1_rstd; // (L, B, T) + floatX* atty; // (L, B, T, C) + // cuDNN saves only some statistics information +#if ENABLE_CUDNN + float* att; // (L, B, NH, T) +#else + floatX* att; // (L, B, NH, T, T) +#endif + + floatX* residual2; // (L, B, T, C) + floatX* ln2; // (L, B, T, C) + float* ln2_mean; // (L, B, T) + float* ln2_rstd; // (L, B, T) + floatX* fch; // (L, B, T, 4*C) + floatX* fch_gelu; // (L, B, T, 4*C) + floatX* residual3; // (L, B, T, C) + floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms + float* lnf_mean; // (B, T) + float* lnf_rstd; // (B, T) + float* losses; // (B, T), will be accumulated in micro-steps + // adding these two compared to the CPU .c code, needed for attention kernel as buffers + floatX* qkvr; // (L, B, T, 3*C) + // in inference mode, this buffer will store the logits + // in training mode, this buffer will contain the *gradients* of the logits. + // during the processing of transformer blocks, we will also use this as a + // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C), + // (B, NH, T, T), and (B, T, V) shaped tensors. + floatX* output; + + // some additional scratch buffers + floatX* scratch_bt4c; // (B, T, 4*C) + floatX* scratch_btc; // (B, T, C) +} ActivationTensors; + + +struct TensorSpec { + void** ptr; + size_t size; + DType type; +}; + + +#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}; + +void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) { + size_t Vp = config.padded_vocab_size; + size_t L = config.num_layers; + size_t NH = config.num_heads; + size_t C = config.channels; + tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); + // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass + tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? L * B * T * C : 0); + tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T); + tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T); + tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C); + #ifdef ENABLE_CUDNN + // FP32 stats tensor for cuDNN to be passed to backward pass + tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T); + #else + tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T); + #endif + tensors[6] = TENSOR_SPEC(data->residual2, L * B * T * C); + // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass + tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0); + tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T); + tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T); + tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C); + // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer + tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C); + tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C); + tensors[13] = TENSOR_SPEC(data->lnf, B * T * C); + tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T); + tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T); + tensors[16] = TENSOR_SPEC(data->losses, B * T); + tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); + tensors[18] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp))); + + tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); + tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); +} + +void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) { + size_t bytes = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + bytes += tensors[i].size * sizeof_dtype(tensors[i].type); + } + + printf0("allocating %d MiB for activations\n", (int)round(bytes / (1024 * 1024))); + + void* acts_memory; + cudaCheck(cudaMalloc((void**)&acts_memory, bytes)); + + // cudaMalloc does not guarantee initial memory values so we memset the allocation here + // this matters because e.g. non-cuDNN attention assumes the attention buffer is zeroed + // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer + cudaCheck(cudaMemset(acts_memory, 0, bytes)); + + char* acts_memory_iterator = (char*)acts_memory; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + // extra protection so we don't accidentally use an empty buffer + if(tensors[i].size == 0) { + *(tensors[i].ptr) = NULL; + }else { + *(tensors[i].ptr) = acts_memory_iterator; + acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type); + } + } + return acts_memory; +} + +typedef struct { + GPT2Config config; + // the weights of the model, and their sizes + ParameterTensors params; + size_t param_elements[NUM_PARAMETER_TENSORS]; + size_t param_sizeof[NUM_PARAMETER_TENSORS]; + void* params_memory; + size_t num_parameters; + size_t num_parameters_bytes; + // gradients of the weights + ParameterTensors grads; + void* grads_memory; + // buffers for the AdamW optimizer + float* m_memory; + float* v_memory; + float* master_weights; // is NULL unless fp32 weights is enabled. + // the activations of the model, and their sizes + ActivationTensors acts; + TensorSpec acts_specs[NUM_ACTIVATION_TENSORS]; + void* acts_memory; + // other run state configuration + int batch_size; // the batch size (B) of current forward pass + int seq_len; // the sequence length (T) of current forward pass + int* inputs; // the input tokens for the current forward pass + int* targets; // the target tokens for the current forward pass + float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps + float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps + float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost + unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. + unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights + int use_master_weights; // keep master weights copy in float for optim update? 0|1 + bool init_state; // set to true if master weights need to be initialized + int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) + int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2 + // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? + int* workload_indices; // encoder_backward, B*T*num_c_groups (int) + int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case +} GPT2; + +void gpt2_init_common(GPT2 *model) { + // common inits outside of the model weights + // memory lazily initialized in forward() + model->acts_memory = NULL; + model->inputs = NULL; + model->targets = NULL; + model->accumulated_mean_loss = NULL; + model->cpu_losses = NULL; + // the B,T params are determined and set, fixed on first batch in forward() + model->batch_size = 0; + model->seq_len = 0; + model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward() + model->params_memory = NULL; + // memory lazily initialized in backward() + model->grads_memory = NULL; + model->workload_indices = NULL; // on cpu, for encoder_backward + model->bucket_info = NULL; // on cpu, for encoder_backward + // memory lazily initialized in update() + model->m_memory = NULL; + model->v_memory = NULL; + model->master_weights = NULL; + // other default settings + model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding + model->use_master_weights = 1; // safe default: do keep master weights in fp32 + model->init_state = true; + model->recompute = 1; // good default: recompute gelu but not layernorm + model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main()) +} + +void gpt2_allocate_weights(GPT2 *model) { + // fill in all the parameter tensor dimensions and types + fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config); + model->num_parameters = 0; + model->num_parameters_bytes = 0; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + model->num_parameters += model->param_elements[i]; + model->num_parameters_bytes += model->param_elements[i] * model->param_sizeof[i]; + } + // create memory for model parameters on the device + assert(model->params_memory == nullptr); + model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); +} + +void gpt2_allocate_state(GPT2 *model, int B, int T) { + printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024))); + assert(model->grads_memory == nullptr); + model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof); + + // record the current B,T as well + model->batch_size = B; + model->seq_len = T; + + // allocate the space + fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute); + model->acts_memory = malloc_and_point_activations(model->acts_specs); + // also create memory for caching inputs and targets + cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); + cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); + cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float))); + cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(float))); + + // initialise cpu scratch buffers for encoder backward + size_t num_c_groups = CEIL_DIV(model->config.channels, (WARP_SIZE * x128::size)); + assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?) + model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups); + model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups); + + // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device + // and returns a status code of 1 if it had to fall back, in that case we want to print warning. + int memory_status = 0; + + // we will now init the optimizer states and master weights + // this is usually a substantial amount of memory allocation right here. + size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for + printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); + assert(model->m_memory == nullptr); + assert(model->v_memory == nullptr); + memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float)); + memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float)); + + if (model->use_master_weights == 1) { + assert(model->master_weights == nullptr); + printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); + memory_status |= cudaMallocConditionallyManaged((void**) &model->master_weights, shard_num_parameters * sizeof(float)); + } + + // report on mixed memory allocation status (re-using our float reduce function, bit awk ok) + int reduced_memory_status = (int) multi_gpu_cpu_float_sum((float)memory_status, &multi_gpu_config); + if (reduced_memory_status >= 1) { + printf0("WARNING: Fell back to cudaMallocManaged when initializing m,v,master_weights on %d GPUs\n", reduced_memory_status); + printf0(" Prevents an OOM, but code may run much slower due to device <-> host memory movement\n"); + } + // report on device memory usage + size_t free, total; + cudaCheck(cudaMemGetInfo(&free, &total)); + printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); + // give an estimate of the maximum batch size + size_t bytes_per_sequence = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + bytes_per_sequence += model->acts_specs[i].size * sizeof_dtype(model->acts_specs[i].type) / B; + } + printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); + printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); +} + +void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { + // write the model to a checkpoint file + printf0("Writing model to %s\n", checkpoint_path); + FILE *model_file = fopenCheck(checkpoint_path, "wb"); + // write the header first + int model_header[256]; + memset(model_header, 0, sizeof(model_header)); + model_header[0] = 20240326; // magic number + assert(PRECISION_MODE == PRECISION_FP32 || PRECISION_MODE == PRECISION_BF16); + model_header[1] = PRECISION_MODE == PRECISION_FP32 ? 3 : 5; // version + model_header[2] = model->config.max_seq_len; + model_header[3] = model->config.vocab_size; + model_header[4] = model->config.num_layers; + model_header[5] = model->config.num_heads; + model_header[6] = model->config.channels; + model_header[7] = model->config.padded_vocab_size; + fwriteCheck(model_header, sizeof(int), 256, model_file); + // write the parameters + device_to_file(model_file, model->params_memory, model->num_parameters_bytes, + IO_BUF_SIZE, main_stream); + // close file, we're done + fcloseCheck(model_file); +} + +void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool weight_init=true) { + // If weight_init is true, we will load the weights from this checkpoint .bin file + // We sometimes want this to be false, if we are going to initialize these weights from + // the master weights that are instead stored in the state .bin file. + // In that case, this function mostly loads the model hyperparameters from the header. + + if (PRECISION_MODE == PRECISION_FP16) { + // TODO for later perhaps, would require us dynamically converting the + // model weights from fp32 to fp16 online, here in this function, or writing + // the fp16 weights directly from Python, which we only do for fp32/bf16 atm. + fprintf(stderr, "build_from_checkpoint() does not support fp16 right now.\n"); + exit(EXIT_FAILURE); + } + + // read in model from a checkpoint file + FILE *model_file = fopenCheck(checkpoint_path, "rb"); + int model_header[256]; + freadCheck(model_header, sizeof(int), 256, model_file); + if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(EXIT_FAILURE); } + int version = model_header[1]; + if (!(version == 3 || version == 5)) { + // 3 = fp32, padded vocab + // 5 = bf16, padded vocab, layernorms also in bf16 + fprintf(stderr, "Bad version in model file\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(EXIT_FAILURE); + } + + // check if the precision mode of the checkpoing matches the model precision + if (weight_init) { + if (PRECISION_MODE == PRECISION_BF16 && version != 5) { + fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path); + fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n"); + exit(EXIT_FAILURE); + } + if (PRECISION_MODE == PRECISION_FP32 && version != 3) { + fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path); + fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n"); + fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n"); + exit(EXIT_FAILURE); + } + } + + // read in hyperparameters + model->config.max_seq_len = model_header[2]; + model->config.vocab_size = model_header[3]; + model->config.num_layers = model_header[4]; + model->config.num_heads = model_header[5]; + model->config.channels = model_header[6]; + model->config.padded_vocab_size = model_header[7]; + + // allocate memory for the model parameters + gpt2_allocate_weights(model); + + // read in the parameters if weight_init is true + if (weight_init) { + assert(model->params_memory != NULL); + file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); + } + fcloseCheck(model_file); + + // only return from this function once we are certain the params are ready on the GPU + cudaCheck(cudaDeviceSynchronize()); +} + +void gpt2_set_hyperparameters(GPT2Config* config, const char* depth_str) { + int depth = atoi(depth_str); + assert(depth > 0); // atoi returns 0 if not a number + int channels, num_heads; + if (depth == 6) { channels = 384; num_heads = 6; } // (unofficial) gpt2-tiny (30M) + else if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M) + else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M) + else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M) + else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M) + else if (depth == 60) { channels = 1920; num_heads = 30; } // (unofficial) 2.7B + else if (depth == 72) { channels = 2880; num_heads = 30; } // (unofficial) 7.3B + else if (depth == 84) { channels = 3456; num_heads = 36; } // (unofficial) 12.2B + else { fprintf(stderr, "Unsupported GPT-2 depth: %d\n", depth); exit(EXIT_FAILURE); } + config->num_layers = depth; + config->channels = channels; + config->num_heads = num_heads; + config->max_seq_len = 1024; +} + +void gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) { + // we use channels instead of depth for GPT-3 because GPT-3 model depths are not one-to-one + // note that our models are not necessarily identical to GPT-3 because + // we use dense attention, not the alternating dense/banded attention of GPT-3 + int channels = atoi(channels_str); + assert(channels > 0); // atoi returns 0 if not a number + int depth, head_size; + if (channels == 384) { depth = 6; head_size = 64; } // (unofficial) gpt3-tiny (31M) + else if (channels == 768) { depth = 12; head_size = 64; } // gpt3-small (125M) + else if (channels == 1024) { depth = 24; head_size = 64; } // gpt3-medium (350M) + else if (channels == 1536) { depth = 24; head_size = 96; } // gpt3-large (760M) + else if (channels == 2048) { depth = 24; head_size = 128; } // gpt3-xl (1.3B) [heads fixed] + else if (channels == 2560) { depth = 32; head_size = 80; } // gpt3-2.7B + else if (channels == 4096) { depth = 32; head_size = 128; } // gpt3-6.7B + else if (channels == 5140) { depth = 40; head_size = 128; } // gpt3-13B + else if (channels == 12288) { depth = 96; head_size = 128; } // gpt3 (175B) + else { fprintf(stderr, "Unsupported GPT-3 channels: %d\n", channels); exit(EXIT_FAILURE); } + assert(channels % head_size == 0); + config->num_layers = depth; + config->channels = channels; + config->num_heads = channels / head_size; + config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2 +} + +void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { + // The model descriptor can be: + // - legacy format "dX", where X is number, e.g. "d12". This creates GPT-2 model with 12 layers. + // - new explicit format "gpt2:dX", same as above, e.g. "gpt2:d48" for GPT-2 with 48 layers. + // - "gpt3:cX", where X is now the channel count, e.g. "gpt3:c768" is the smallest GPT-3 model. + + // check the valid prexies and dispatch to the right setup function + assert(descriptor != NULL); + size_t len = strlen(descriptor); + if (len > 1 && descriptor[0] == 'd') { + gpt2_set_hyperparameters(&model->config, descriptor + 1); // pass along the depth str without the 'd' + } else if (len > 6 && strncmp(descriptor, "gpt2:d", 6) == 0) { + gpt2_set_hyperparameters(&model->config, descriptor + 6); // pass along the depth str without the 'gpt2:d' + } else if (len > 6 && strncmp(descriptor, "gpt3:c", 6) == 0) { + gpt3_set_hyperparameters(&model->config, descriptor + 6); // pass along the channels str without the 'gpt3:c' + } else { + fprintf(stderr, "Unsupported model descriptor: %s\n", descriptor); exit(EXIT_FAILURE); + } + + // both GPT-2 and GPT-3 use the same tokenizer with 50257 tokens + model->config.vocab_size = 50257; + model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel efficiency + + gpt2_allocate_weights(model); + + // allocate and random init the memory for all the parameters with GPT-2 schema + // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5) + // NOTE: assuming all parameters are of the type floatX, could be relaxed later + mt19937_state init_rng; + manual_seed(&init_rng, 42); + floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes); + memset(params_memory_cpu, 0, model->num_parameters_bytes); + // fill in all the weights with random values + float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); + // we have to init all these tensors exactly in the order that PyTorch initializes them + // so that we can match them up and get correctness and exactly the same initial conditions + size_t L = model->config.num_layers; + size_t offset = 0; + for (int l = 0; l < L; l++) { + offset = 0; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + // the layernorm parameters are all initialized to 1 + if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once + for (size_t j = 0; j < model->param_elements[i]; j++) { + params_memory_cpu[offset + j] = 1.0f; + } + } + // weights tensors are handled here + if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors + || i == 4 || i == 6 || i == 10 || i == 12) { + size_t n = model->param_elements[i]; + size_t layer_offset = 0; + if (i == 0) { + // for wte tensor (padded vocab) override to init V instead of Vp rows + n = model->config.vocab_size * model->config.channels; + } + if (i == 4 || i == 6 || i == 10 || i == 12) { + // weight tensors, we are only initializing layer l + assert(n % L == 0); + n = n / L; + layer_offset = l * n; + } + // in GPT-2, the projections back into the residual stream are additionally + // scaled by 1/sqrt(2*L) for training stability + float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f; + // okay let's draw the random numbers and write them + float *fp32_buffer = (float*)mallocCheck(n * sizeof(float)); + normal_(fp32_buffer, n, 0.0f, scale, &init_rng); + for (size_t j = 0; j < n; j++) { + params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j]; + } + free(fp32_buffer); + } + offset += model->param_elements[i]; + } + } + + // copy them to GPU + cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); + free(params_memory_cpu); +} + +// propagate inputs through the network to produce logits. +// right now, this function is fully synchronous with the host +void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { + NVTX_RANGE_FN(); + // we must be careful and use size_t instead of int, otherwise + // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16. + + // ensure the model was initialized or error out + if (model->params_memory == NULL) { + printf("Error: model was not initialized properly.\n"); + exit(EXIT_FAILURE); + } + + // convenience parameters + const size_t V = model->config.vocab_size; + const size_t Vp = model->config.padded_vocab_size; + const size_t L = model->config.num_layers; + const size_t NH = model->config.num_heads; + const size_t C = model->config.channels; + + // validate B,T are not larger than the values used at initialisation + // (smaller B,T are okay for inference only) + if (B > model->batch_size || T > model->seq_len) { + printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T); + exit(EXIT_FAILURE); + } + + // copy inputs/targets to the model + cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); + // validate inputs, all indices must be in the range [0, V) + // we can do this while the copies are already underway + tokenCheck(inputs, B*T, V); + + // forward pass + ParameterTensors params = model->params; // for brevity + ActivationTensors acts = model->acts; + encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0] + + // first layernorm isn't fused + layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream); + + for (int l = 0; l < L; l++) { + NvtxRange layer_range("Layer", l); + + floatX* residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + floatX* l_qkvw = params.qkvw + l * 3*C * C; + floatX* l_qkvb = params.qkvb + l * 3*C; + floatX* l_attprojw = params.attprojw + l * C * C; + floatX* l_attprojb = params.attprojb + l * C; + floatX* l_ln2w = params.ln2w + l * C; + floatX* l_ln2b = params.ln2b + l * C; + floatX* l_fcw = params.fcw + l * 4*C * C; + floatX* l_fcb = params.fcb + l * 4*C; + floatX* l_fcprojw = params.fcprojw + l * C * 4*C; + floatX* l_fcprojb = params.fcprojb + l * C; + + // get the pointers of the activations for this layer + floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; + floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; + floatX* l_atty = acts.atty + l * B * T * C; + floatX* l_residual2 = acts.residual2 + l * B * T * C; + floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + floatX* l_fch = acts.fch + l * B * T * 4*C; + // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward + // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size + floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; + floatX* l_residual3 = acts.residual3 + l * B * T * C; + floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. + + // now do the forward pass + #ifdef ENABLE_CUDNN + float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor + matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); + attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); + #else + floatX* l_att = acts.att + l * B * NH * T * T; + if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) + cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); + } + // these are only needed as scratchpads for the forward pass, but + // need not be stored for backward + matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); + attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); + #endif + + matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); + fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream); + matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); + matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); + // OK, fusion across blocks. + if(l+1 != L) { + floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf; + float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T; + float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T; + const floatX* l_ln1w = params.ln1w + (l + 1) * C; + const floatX* l_ln1b = params.ln1b + (l + 1) * C; + fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b, + B * T, C, main_stream); + } else { + fused_residual_forward5(l_residual3, acts.lnf, acts.lnf_mean, acts.lnf_rstd, l_residual2, scratch, + params.lnfw, params.lnfb, + B * T, C, main_stream); + } + } + + matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + cudaCheck(cudaDeviceSynchronize()); +} + + +// Forwards both the model and the loss and is used for validation splits and evals. +// In particular it populates cpu_losses with loss at each token. +// Some of the evals (e.g. HellaSwag) require the per-token losses, which are produced here. +float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) { + assert(targets != NULL); + // forward the model itself + gpt2_forward(model, inputs, B, T); + // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow + const size_t V = model->config.vocab_size; + const size_t Vp = model->config.padded_vocab_size; + + NvtxRange classifier_and_loss_range("classifier_and_loss"); + ActivationTensors acts = model->acts; + float mean_loss = 0.0f; + // fused classifier: does the forward pass and first part of the backward pass + const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements + // note: we don't need to generate dlogits here + cudaCheck(cudaMemset(acts.losses, 0, B*T*sizeof(float))); + cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); + tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets + fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream); + cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); + for (int i = 0; i < B*T; i++) { + mean_loss += model->cpu_losses[i]; + } + mean_loss /= B*T; + cudaCheck(cudaDeviceSynchronize()); + return mean_loss; +} + +void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { + if(model->grads_memory == nullptr) { + fprintf(stderr, "Need to allocate gradients before backward"); + exit(EXIT_FAILURE); + } + NVTX_RANGE_FN(); + bool last_step = micro_step == grad_accum_steps - 1; + // on the first micro-step zero the gradients, as we're about to += accumulate into them + if (micro_step == 0) { + // there are currently two state vars during the gradient accumulation inner loop: + // 1) the losses accumulate += into acts.losses, reset here + // 2) the gradients accumulate += into grads_memory, reset here + cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream)); + cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream)); + } + + // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow + const size_t B = model->batch_size; + const size_t T = model->seq_len; + const size_t V = model->config.vocab_size; + const size_t Vp = model->config.padded_vocab_size; + const size_t L = model->config.num_layers; + const size_t NH = model->config.num_heads; + const size_t C = model->config.channels; + + ParameterTensors params = model->params; // for brevity + ParameterTensors grads = model->grads; + ActivationTensors acts = model->acts; + + // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier + NvtxRange classifier_and_loss_range("classifier_and_loss"); + const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements + cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); + tokenCheck(targets, B*T, V); + fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream); + + // backward pass: go in the reverse order of the forward pass, and call backward() functions + + // reset residual stream gradients (put here to work with gradient accumulation) + floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass + cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX))); + + // re-use the output buffer of the forward pass as a scratchpad during backward pass + float* scratchF = (float*)acts.output; + floatX* scratchX = (floatX*)acts.output; + + // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) + // this was done in the fused classifier kernel as last step of forward pass + // technically that is a small, inline backward() pass of calculating + // total, final loss as the mean over all losses over all (B,T) positions in the batch + // next: backward the classifier matmul + matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + // backward the final layernorm + floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 + layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream); + + // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic + // scratch for backward computations + floatX* dl_btc = residual; + + // now backward all the layers + for (int l = L-1; l >= 0; l--) { + NvtxRange layer_range("Layer", l); + + residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + floatX* l_ln1w = params.ln1w + l * C; + floatX* l_ln1b = params.ln1b + l * C; + floatX* l_qkvw = params.qkvw + l * 3*C * C; + floatX* l_attprojw = params.attprojw + l * C * C; + floatX* l_ln2w = params.ln2w + l * C; + floatX* l_ln2b = params.ln2b + l * C; + floatX* l_fcw = params.fcw + l * 4*C * C; + floatX* l_fcprojw = params.fcprojw + l * C * 4*C; + // get the pointers of the gradients of the weights for this layer + floatX* dl_ln1w = grads.ln1w + l * C; + floatX* dl_ln1b = grads.ln1b + l * C; + floatX* dl_qkvw = grads.qkvw + l * 3*C * C; + floatX* dl_qkvb = grads.qkvb + l * 3*C; + floatX* dl_attprojw = grads.attprojw + l * C * C; + floatX* dl_attprojb = grads.attprojb + l * C; + floatX* dl_ln2w = grads.ln2w + l * C; + floatX* dl_ln2b = grads.ln2b + l * C; + floatX* dl_fcw = grads.fcw + l * 4*C * C; + floatX* dl_fcb = grads.fcb + l * 4*C; + floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C; + floatX* dl_fcprojb = grads.fcprojb + l * C; + // get the pointers of the activations for this layer + floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; + floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; + floatX* l_atty = acts.atty + l * B * T * C; + floatX* l_residual2 = acts.residual2 + l * B * T * C; + floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C; + floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; + // get the pointers of the gradients of the activations for this layer + // notice that there is no l *, because we just have a single copy, and keep + // re-using this memory in every Transformer block as we calculate backward pass + + floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c; + + // start the backward pass for this layer + if(model->recompute >= 1) { + // recompute >= 1 means we recompute gelu. in this case, + // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here + gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream); + } + matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion); + if(model->recompute >= 2) { + // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand + layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream); + } + matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream); + // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above + layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream); + matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); + + #ifdef ENABLE_CUDNN + float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor + attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + #else + floatX* l_att = acts.att + l * B * NH * T * T; + // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory + floatX* buffer_a = l_atty; + floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need + attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); + #endif + if(model->recompute >= 2) { + layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream); + } + // QKV parameter gradients + matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream); + // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above + layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream); + + // Accumulate gradients from this layer in a background stream. + if(last_step) { + floatX* const pointers[] = { + dl_ln1w, dl_ln1b, + dl_qkvw, dl_qkvb, + dl_attprojw, dl_attprojb, + dl_ln2w, dl_ln2b, + dl_fcw, dl_fcb, + dl_fcprojw, dl_fcprojb + }; + const size_t nelem[] = { + C, C, + 3 * C * C, 3 * C, + C * C, C, + C, C, + 4 * C * C, 4 * C, + C * 4 * C, C + }; + multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); + } + } + encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info, + dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); + + // Aggregate all gradients that are not part of the transformer blocks + if(last_step) { + // reduce all the losses within the current GPU (across all microsteps) + global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream); + // reduce loss across GPUs to a single, final float across all microsteps and GPUs + #if MULTI_GPU + ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); + #endif + cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); + // reduce the gradients for non-transformer block parameters + floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb}; + const size_t nelem[] = {Vp * C, T * C, C, C}; + multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); + } + + cudaCheck(cudaDeviceSynchronize()); + if(last_step) { + model->mean_loss /= B*T*grad_accum_steps; + } else { + model->mean_loss = -1.f; // no loss available yet + } +} + +// Gets the offset of a specific tensor for a specific layer in the GPT2 model +// layer_id is ignored for weights that are not part of a transformer block +ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) { + // first offset our way to the parameter tensor start + ptrdiff_t offset = 0; + for (int i = 0; i < param_tensor_id; i++) { + offset += (ptrdiff_t)model->param_elements[i]; + } + size_t size = model->param_elements[param_tensor_id] ; + // if we are in the transformer block, we need to additionally offset by the layer id + if(2 <= param_tensor_id && param_tensor_id <= 13) { + size /= model->config.num_layers; + offset += (ptrdiff_t)(layer_id * size); + } + return {offset, size}; +} + +float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { + NVTX_RANGE_FN(); + floatX* grads_memory = (floatX*)model->grads_memory; + + // repurposing this buffer (which isn't needed now) to write grad norm into it + float* grad_norm_squared = (float*)model->acts.output; + float grad_norm_squared_cpu = 0.0f; + + int num_slices[2] = {1, model->config.num_layers}; + int max_num_block_sums = get_max_num_block_sums(num_slices, 2); + if (multi_gpu_config->zero_stage == 1) { + // because of the ncclReduceScatter() in backward, + // grads_memory only contains the averaged gradients at the local shards, + // so we only calculate the grad norm at the grads_memory belonging to the local shards + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); + ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); + ptrdiff_t offset = tensor.offset + shard.offset; + bool is_first_pass = (i == 0); + if((i < 2 || i > 13)) { + global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, 0, 1, + max_num_block_sums, is_first_pass, main_stream); + } else { + global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, tensor.size, model->config.num_layers, + max_num_block_sums, is_first_pass, main_stream); + } + } + global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); +#if MULTI_GPU + // further sum the (partial) squared norm across all GPUs + ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream)); +#endif + } else { + // in regular DDP, backward has averaged the gradients across all GPUs + // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed + global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream); + global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); + } + cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); + float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); + return grad_norm_cpu; +} + +void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, + MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { + // update the model parameters using the AdamW optimizer + // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs + // so we may not be responsible for the entire parameter tensor + // also, this function was very simple a while back but become very complex, only because we want to + // selectively weight decay some, but not all tensors :( + // TODO: revisit and probably refactor this entire function + NVTX_RANGE_FN(); + if(model->grads_memory == nullptr || model->m_memory == nullptr || model->v_memory == nullptr) { + fprintf(stderr, "Need to allocate optimizer state before update"); + exit(EXIT_FAILURE); + } + + bool init_state = model->init_state; + if(init_state) { + model->init_state = false; + NvtxRange rng("InitOpt"); + cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + } + + // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint + model->rng_state_last_update = model->rng_state; + + // AdamW update + // handle adamw for all the transformer blocks + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + // generate a unique seed for each tensor + unsigned int seed = random_u32(&model->rng_state); + + int num_layers = model->config.num_layers; + if((i < 2 || i > 13)) { + num_layers = 1; + } + + ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); + ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); + ptrdiff_t local_offset_full = tensor.offset + shard.offset; + ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes; + + // we only want to weight decay the 2D tensors and leave all 1D tensors alone + // in particular this also decays the embedding weights, but this is ok: + // - the token embeddings are weight shared and participate in the final projection to logits + // - the position embeddings actively participate at every forward/backward pass + float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f; + floatX* param_ptr = (floatX*)model->params_memory + local_offset_full; + floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full; + + ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? local_offset_full : local_offset_partial; + float* m_ptr = model->m_memory + opt_state_offset; + float* v_ptr = model->v_memory + opt_state_offset; + float* master_ptr = nullptr; + if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; } + if(init_state && model->master_weights != nullptr ) { + size_t grid_size = CEIL_DIV(shard.size, 512); + copy_and_cast_kernel<<>>(master_ptr, param_ptr, shard.size, + shard.size, tensor.size); + cudaCheck(cudaGetLastError()); + } + + if (init_from_master_only) { + // when resuming training from a checkpoint with master weights (allows changing precision) + init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream); + } else { + // ok finally call the kernel to update the weights with AdamW + adamw_update(param_ptr, master_ptr, grad_ptr, + m_ptr, v_ptr, + shard.size, tensor.size, tensor.size, shard.size, num_layers, + learning_rate, + beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); + } + + if (multi_gpu_config->zero_stage == 1) { +#if MULTI_GPU + ncclCheck(ncclGroupStart()); + for(int l = 0; l < num_layers; ++l) { + // gather updated shards of model->params_memory from each process + ncclCheck(ncclAllGather(param_ptr + l * tensor.size, + (floatX*) model->params_memory + tensor.offset + l * tensor.size, + shard.size, ncclFloatX, + multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); + } + ncclCheck(ncclGroupEnd()); +#endif + } + } + + cudaCheck(cudaDeviceSynchronize()); +} + +float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { + /* + Estimate model flops utilization (MFU) + ref: Section 2.1 of https://arxiv.org/pdf/2001.08361 + Note: Ideally, the N here would be only the parameters that actually + participate in matrix multiplications. In this N, we are over-estimating by + including LayerNorm params, biases, and the position embedding weights, + but these are very small terms. Also keep in mind that we would want to exclude + the token embedding weights, but in GPT-2 these are weight shared, so they + participate in the classifier matmul, so they are correct to be included in N. + Note 2: The first term (6 * N) in flops_per_token is all weight matmuls, the + second is the attention matmul, which is also usually a small contribution. + */ + size_t N = model->num_parameters; + int L = model->config.num_layers; + int C = model->config.channels; + int T = model->seq_len; + size_t flops_per_token = 6 * N + (size_t)6 * L * C * T; + size_t flops_per_step = flops_per_token * num_tokens; + // express our flops throughput as ratio of A100 bfloat16 peak flops + float flops_achieved = (float)flops_per_step * (1.0f / dt); // per second + float flops_promised = get_flops_promised(deviceProp.name, PRECISION_MODE) * 1e12f; + if(flops_promised < 0) { + return -1.f; // don't know + } + float mfu = flops_achieved / flops_promised; + return mfu; +} + +void gpt2_free(GPT2 *model) { + cudaFreeCheck(&model->params_memory); + cudaFreeCheck(&model->grads_memory); + cudaFreeCheck(&model->m_memory); + cudaFreeCheck(&model->v_memory); + cudaFreeCheck(&model->master_weights); + cudaFreeCheck(&model->acts_memory); + cudaFreeCheck(&model->inputs); + cudaFreeCheck(&model->targets); + cudaFreeCheck(&model->accumulated_mean_loss); + cudaCheck(cudaFreeHost(model->cpu_losses)); + free(model->workload_indices); + free(model->bucket_info); +} + +// ---------------------------------------------------------------------------- +// common init & free code for all of train/test/profile + +void common_start(bool override_enable_tf32 = true, bool print_device_info = true) { + + // get CUDA device infos + cudaCheck(cudaGetDeviceProperties(&deviceProp, multi_gpu_config.local_device_idx)); + if (print_device_info) { + printf("[System]\n"); + printf("Device %d: %s\n", multi_gpu_config.local_device_idx, deviceProp.name); + } + + // set up the cuda streams. atm everything is on the single main stream + cudaCheck(cudaStreamCreate(&main_stream)); + nvtxNameCudaStreamA(main_stream, "main stream"); + + // set up cuBLAS and cuBLASLt + cublasCheck(cublasLtCreate(&cublaslt_handle)); + cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); + + // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') + bool enable_tf32 = PRECISION_MODE == PRECISION_FP32 && deviceProp.major >= 8 && override_enable_tf32; + cublas_compute = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; + + #ifdef ENABLE_CUDNN + create_cudnn(); + #endif +} + +void common_free(GPT2 &model) { + cudaCheck(cudaStreamDestroy(main_stream)); + cudaCheck(cudaFree(cublaslt_workspace)); + cublasCheck(cublasLtDestroy(cublaslt_handle)); + #ifdef ENABLE_CUDNN + destroy_cudnn(); + #endif +} + + +void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) { + printf("Writing state to %s\n", filename); + FILE *state_file = fopenCheck(filename, "wb"); + int state_header[256]; + memset(state_header, 0, sizeof(state_header)); + // basic identifying information + state_header[0] = 20240527; // magic number + state_header[1] = 1; // version number + state_header[2] = multi_gpu_config.num_processes; // number of processes + state_header[3] = multi_gpu_config.process_rank; // rank of this process + state_header[4] = model->use_master_weights; // whether we're using fp32 master weights + state_header[5] = loader->should_shuffle; // shuffle state of the dataloader + // int main state, start at 10 to leave some padding + state_header[10] = step; // step of the optimization + // model rng state, start at 20 to leave some padding + *((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state + *((unsigned long long*)&state_header[22]) = model->rng_state_last_update; // last gpt2_update + // dataloader state, start at 30 to leave some padding + *((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset + *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard + fwriteCheck(state_header, sizeof(int), 256, state_file); + + // write AdamW m, v, and master_weights here (they are all float) + size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; + device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + if(model->use_master_weights) { + device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + } + + // write dataloader state if we are using the Permuted version of it + if (loader->should_shuffle) { + fwriteCheck(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file); // number of shards + fwriteCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file); + fwriteCheck(&loader->shard_num_samples, sizeof(size_t), 1, state_file); + fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); + fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); + } + fcloseCheck(state_file); +} + +void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) { + FILE *state_file = fopenCheck(filename, "rb"); + int state_header[256]; + freadCheck(state_header, sizeof(int), 256, state_file); + assert(state_header[0] == 20240527); // magic number + assert(state_header[1] == 1); // version number + assert(state_header[2] == multi_gpu_config.num_processes); // number of processes + assert(state_header[3] == multi_gpu_config.process_rank); // rank of this process + int use_master_weights = state_header[4]; // whether we're using fp32 master weights + int should_shuffle = state_header[5]; // shuffle state of the dataloader + *step = state_header[10]; // step of the optimization + model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state + model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update + size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index + size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard + + // read AdamW m, v, master_weights (they are all float) + // allocate all the needed memory as necessary + size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; + if(use_master_weights == 1 && !model->use_master_weights) { + printf0("Warning: Master weights are present in state, but not enabled for current run."); + } else if (use_master_weights == 0 && model->use_master_weights) { + printf0("Error: Master weights requested, but not present in state file."); + exit(EXIT_FAILURE); + } + + model->init_state = false; // we just got the state from file, no need to do first-touch init + assert(model->m_memory != nullptr); + assert(model->v_memory != nullptr); + file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + if(model->use_master_weights) { + assert(model->master_weights != nullptr); + file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + // restore weights from the master weights using the RNG state before last weight update + model->rng_state = model->rng_state_last_update; + gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); + model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this + } + + // revive the DataLoader object and its state + loader->should_shuffle = should_shuffle; + if (should_shuffle == 1) { + // ensure the number of shards matches + size_t glob_result_gl_pathc; + freadCheck(&glob_result_gl_pathc, sizeof(size_t), 1, state_file); + assert(glob_result_gl_pathc == loader->glob_result.gl_pathc); + // read the shard indices + loader->shard_indices = (int*)mallocCheck(loader->glob_result.gl_pathc * sizeof(int)); + freadCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file); + // ensure the number of samples matches + size_t shard_num_samples; + freadCheck(&shard_num_samples, sizeof(size_t), 1, state_file); + assert(shard_num_samples == loader->shard_num_samples); + // read the intra-shard indices + loader->intra_shard_indices = (int*)mallocCheck(loader->shard_num_samples * sizeof(int)); + freadCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); + // read the shuffle rng state + freadCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); + } + dataloader_resume(loader, current_shard_idx, current_sample_idx); + + // all done, close state file + fcloseCheck(state_file); +} + +void write_checkpoint(const char* output_log_dir, int step, GPT2* model, DataLoader* train_loader, MultiGpuConfig* multi_gpu_config) { + // a checkpoint contains: model weights, optimizer/dataloader state, and a DONE file + printf0("Writing checkpoint at step %d\n", step); + int rank = multi_gpu_config->process_rank; + // only rank 0 writes the model file because it is the same across all ranks + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, step); + gpt2_write_to_checkpoint(model, filename_buffer); + } + // all ranks write their state file + snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, step, rank); + save_state(filename_buffer, step, model, train_loader); + // DONE file is a signal that this checkpoint as a whole is complete + multi_gpu_barrier(multi_gpu_config); + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/DONE_%08d", output_log_dir, step); + FILE* done_file = fopenCheck(filename_buffer, "w"); + fcloseCheck(done_file); + } +} + +void delete_checkpoint(const char* output_log_dir, int step, MultiGpuConfig* multi_gpu_config) { + // mirrors write_checkpoint function, cleans up checkpoint from disk + printf0("Deleting checkpoint at step %d\n", step); + int rank = multi_gpu_config->process_rank; + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, step); + remove(filename_buffer); + } + snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, step, rank); + remove(filename_buffer); + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/DONE_%08d", output_log_dir, step); + remove(filename_buffer); + } +} + +#ifndef TESTING +// if we are TESTING (see test_gpt2.cu), we'll skip everything below this point + +// ---------------------------------------------------------------------------- +// training resumption logic, very useful when jobs crash once in a while +// the goal is that we can resume optimization from any checkpoint, bit-perfect +// note that "state" refers to things not already saved in the model checkpoint file + +// ---------------------------------------------------------------------------- +// CLI, poor man's argparse +// (all single letters have been claimed now) + +void error_usage() { + fprintf(stderr, "Usage: ./train_gpt2cu [options]\n"); + fprintf(stderr, "Options:\n"); + // file system input / output + fprintf(stderr, " -i train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n"); + fprintf(stderr, " -j val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n"); + fprintf(stderr, " -e input .bin filename or descriptor, see code comments as docs. (default = gpt2_124M_bf16.bin)\n"); + fprintf(stderr, " -o output log dir (default = NULL, no logging)\n"); + fprintf(stderr, " -lg log gpu info every x steps (default = -1; disabled)\n"); + fprintf(stderr, " -n write optimization checkpoints every how many steps? (default 0, don't)\n"); + fprintf(stderr, " -nk max number of checkpoints to keep in the directory, removing old ones (0 = disable, default)\n"); + fprintf(stderr, " -nm every how many step checkpoints are considered major? major checkpoints never get deleted.\n"); + fprintf(stderr, " -y resume optimization found inside output log dir? (0=restart/overwrite, 1=resume/append)\n"); + // token layout for each step of the optimization + fprintf(stderr, " -b (per-GPU, micro) batch size B (default = 4)\n"); + fprintf(stderr, " -t sequence length T (default = 1024)\n"); + fprintf(stderr, " -d total desired batch size (default = B * T * num_processes, i.e. no grad accumulation\n"); + // workload (number of steps) + fprintf(stderr, " -x max_steps of optimization to run (-1 (default) = disable, run 1 epoch)\n"); + // optimization + fprintf(stderr, " -k learning rate scheduler (default = cosine)\n"); + fprintf(stderr, " -l learning rate (default = 3e-4f)\n"); + fprintf(stderr, " -u learning rate warmup iterations (default = 0, no warmup)\n"); + fprintf(stderr, " -q learning rate decay: final fraction, at end of training (default = 1.0 (no decay))\n"); + fprintf(stderr, " -c weight decay (default = 0.0f)\n"); + fprintf(stderr, " -sl outlier stability: skip update if loss goes above this in zscore (0.0f=off)\n"); + fprintf(stderr, " -sg outlier stability: skip update if grad_norm goes above this in zscore (0.0f=off)\n"); + // evaluation + fprintf(stderr, " -v val_loss_every, how often we evaluate val loss (default = 20)\n"); + fprintf(stderr, " -m val_max_steps, up to how many val batches to estimate val loss? (default = 20)\n"); + fprintf(stderr, " -s sample_every, how often we inference the model (default = 20)\n"); + fprintf(stderr, " -g genT, how many steps of inference we do (default = 64)\n"); + fprintf(stderr, " -h hellaswag eval run? (default = 0)\n"); + // debugging + fprintf(stderr, " -a overfit a single batch? 0/1. useful for debugging\n"); + // numerics + fprintf(stderr, " -f enable_tf32 override (default: 1, set to 0 to disable tf32)\n"); + fprintf(stderr, " -w keep f32 copy of weights for the optimizer? (default: 1)\n"); + fprintf(stderr, " -ge gelu fusion: 0=none, 1=forward, 2=forward+backward (default: 2 for >=SM90, 0 for older GPUs)\n"); + // memory management + fprintf(stderr, " -z zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n"); + fprintf(stderr, " -r recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\n"); + // multi-node settings + fprintf(stderr, " -pn num_processes (default = 1)\n"); + fprintf(stderr, " -pr process_rank (default = 0)\n"); + fprintf(stderr, " -pg gpus_per_node (default = 8)\n"); + fprintf(stderr, " -pm nccl_init_method: tcp,fs,mpi (default = mpi)\n"); + fprintf(stderr, " -ps server_ip - used only when nccl_init_method is tcp (default = -1)\n"); + fprintf(stderr, " -pp fs_path - used only when nccl_init_method is fs (default = /tmp)\n"); + exit(EXIT_FAILURE); +} + +// ---------------------------------------------------------------------------- +// main training loop +int main(int argc, char *argv[]) { + // read in the (optional) command line arguments + const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; + const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin"; + const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model + const char* lr_scheduler_type = "cosine"; + const char* output_log_dir = NULL; + int checkpoint_every = 0; // write checkpoints every how many steps? + int checkpoints_keep = 0; // how long checkpoint history do we keep? (in units of checkpoints) + int major_checkpoint_every = 0; // major checkpoints never get deleted when maintaining history + int resume = 0; // resume the optimization, if one is found inside output_log_dir? + int B = 4; // batch size + int T = 1024; // sequence length max + int total_batch_size = -1; // will be calculated down below later, if not provided + float learning_rate = 3e-4f; + int log_gpu_every = -1; + int warmup_iterations = 0; + float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training + float weight_decay = 0.0f; + float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore + float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore + int val_loss_every = 20; // every how many steps do we eval validation loss? + int val_max_steps = 20; // how many batches max do we eval for validation loss? + int sample_every = 20; // every how many steps to do inference? + int genT = 64; // number of steps of inference we will do + int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once + int max_steps = -1; + int override_enable_tf32 = 1; + int use_master_weights = 1; + int gelu_fusion = -1; // 0 = none, 1 = forward, 2 = forward+backward (-1 => per-GPU default) + int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu + int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training + int hellaswag_eval = 0; + // multi-node settings + int num_processes = 1; // this should be set by the slurm environment + int process_rank = 0; // this should be set by the slurm environment + int gpus_per_node = 8; // this should be set by the slurm environment + char nccl_init_method[256] = "mpi"; // "tcp" or "fs" or "mpi" + char server_ip[256] = ""; // used if init_method set to "tcp" -> set to your server ip address + char fs_path[256] = ""; // used if init_method set to "fs" -> set to a shared filesystem path + for (int i = 1; i < argc; i+=2) { + if (i + 1 >= argc) { error_usage(); } // must have arg after flag + if (argv[i][0] != '-') { error_usage(); } // must start with dash + if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { error_usage(); } // must be -x[y] (one dash, one or two letters) + // read in the args + if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; } + else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; } + else if (argv[i][1] == 'e') { load_filename = argv[i+1]; } + else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; } + else if (argv[i][1] == 'n' && argv[i][2] == '\0') { checkpoint_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); } + else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size + else if (argv[i][1] == 't') { T = atoi(argv[i+1]); } + else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); } + else if (argv[i][1] == 'l' && argv[i][2] == '\0') { learning_rate = atof(argv[i+1]); } + else if (argv[i][1] == 'l' && argv[i][2] == 'g') { log_gpu_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'u') { warmup_iterations = atoi(argv[i+1]); } + else if (argv[i][1] == 'q') { final_learning_rate_frac = atof(argv[i+1]); } + else if (argv[i][1] == 'c') { weight_decay = atof(argv[i+1]); } + else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); } + else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == '\0') { sample_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'g' && argv[i][2] == 'e') { gelu_fusion = atoi(argv[i+1]); } + else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); } + else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); } + else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); } + else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); } + else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); } + else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); } + else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); } + else if (argv[i][1] == 'k') { lr_scheduler_type = argv[i+1]; } + else if (argv[i][1] == 'p' && argv[i][2] == 'i') { strcpy(nccl_init_method, argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'f') { strcpy(fs_path, argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 's') { strcpy(server_ip, argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'n') { num_processes = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'r') { process_rank = atoi(argv[i+1]); } + else if (argv[i][1] == 'p' && argv[i][2] == 'g') { gpus_per_node = atoi(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == 'l') { skip_update_lossz = atof(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); } + else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); } + else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); } + else { error_usage(); } + } + + multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); + common_start(override_enable_tf32, false); // common init code for train/test/profile + + // should do a bit more error checking here + assert(warmup_iterations >= 0); + if (output_log_dir != NULL) { + assert(strlen(output_log_dir) < 400); // careful bunch of hardcoded snprintf around this + } + int tokens_per_fwdbwd = B * T * multi_gpu_config.num_processes; // one micro-batch processes this many tokens + // calculate sensible default for total batch size as assuming no gradient accumulation + if (total_batch_size == -1) { total_batch_size = tokens_per_fwdbwd; } + // in the future, we might want to set gelu fusion to 2 for SM90+ and 0 for other GPUs + if (gelu_fusion == -1) { gelu_fusion = 0; } // (deviceProp.major >= 9) ? 2 : 0; } // in gpt2_init_common for test_gpt2cu... + // calculate the number of gradient accumulation steps from the desired total batch size + assert(total_batch_size % tokens_per_fwdbwd == 0); + int grad_accum_steps = total_batch_size / tokens_per_fwdbwd; + // if we're only overfitting a single batch for debugging, let's overfit the first batch + // from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same) + if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; } + printf0("+-----------------------+----------------------------------------------------+\n"); + printf0("| Parameter | Value |\n"); + printf0("+-----------------------+----------------------------------------------------+\n"); + printf0("| train data pattern | %-50s |\n", train_data_pattern); + printf0("| val data pattern | %-50s |\n", val_data_pattern); + printf0("| output log dir | %-50s |\n", output_log_dir == NULL ? "NULL" : output_log_dir); + printf0("| checkpoint_every | %-50d |\n", checkpoint_every); + printf0("| resume | %-50d |\n", resume); + printf0("| micro batch size B | %-50d |\n", B); + printf0("| sequence length T | %-50d |\n", T); + printf0("| total batch size | %-50d |\n", total_batch_size); + printf0("| LR scheduler | %-50s |\n", lr_scheduler_type); + printf0("| learning rate (LR) | %-50e |\n", learning_rate); + printf0("| warmup iterations | %-50d |\n", warmup_iterations); + printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac); + printf0("| weight decay | %-50e |\n", weight_decay); + printf0("| skip update lossz | %-50f |\n", skip_update_lossz); + printf0("| skip update gradz | %-50f |\n", skip_update_gradz); + printf0("| max_steps | %-50d |\n", max_steps); + printf0("| val_loss_every | %-50d |\n", val_loss_every); + printf0("| val_max_steps | %-50d |\n", val_max_steps); + printf0("| sample_every | %-50d |\n", sample_every); + printf0("| genT | %-50d |\n", genT); + printf0("| overfit_single_batch | %-50d |\n", overfit_single_batch); + printf0("| use_master_weights | %-50s |\n", use_master_weights ? "enabled" : "disabled"); + printf0("| gelu_fusion | %-50d |\n", gelu_fusion); + printf0("| recompute | %-50d |\n", recompute); + printf0("+-----------------------+----------------------------------------------------+\n"); + const char* precision_str = (PRECISION_MODE == PRECISION_FP32) + ? (cublas_compute == CUBLAS_COMPUTE_32F_FAST_TF32 ? "TF32" : "FP32") + : (PRECISION_MODE == PRECISION_FP16 ? "FP16" : "BF16"); + printf0("| device | %-50s |\n", deviceProp.name); + printf0("| peak TFlops | %-50.1f |\n", get_flops_promised(deviceProp.name, PRECISION_MODE)); + printf0("| precision | %-50s |\n", precision_str); + printf0("+-----------------------+----------------------------------------------------+\n"); + + // figure out if we are going to be resuming the optimization + int resuming = 0; + // find the DONE file with the highest step count + int resume_max_step = find_max_step(output_log_dir); + if (resume == 1) { // is -y 1 resume flag set? + assert(output_log_dir != NULL); + if (resume_max_step != -1) { + resuming = 1; // -y 1 is set, and we found a checkpoint we can resume from + snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, resume_max_step); + } + } + + // build the GPT-2 model + GPT2 model; + gpt2_init_common(&model); + if (resuming == 1) { + // if `-y 1` was set, then we are resuming from the latest checkpoint + // if we are using master weights, we'll init them later inside load_state() + bool weight_init = !use_master_weights; + gpt2_build_from_checkpoint(&model, filename_buffer, weight_init); + } else if (ends_with_bin(load_filename)) { + // otherwise, if this is a .bin file, we assume it's a model, let's init from it + gpt2_build_from_checkpoint(&model, load_filename); + } else { + // if it's not .bin, it could be a "special descriptor". This descriptor is used to + // construct GPT-2 / GPT-3 models in a convenient format. See the function for docs. + gpt_build_from_descriptor(&model, load_filename); + } + + model.use_master_weights = use_master_weights; + model.gelu_fusion = gelu_fusion; + model.recompute = recompute; + printf0("| weight init method | %-50s |\n", resuming == 1 ? "intermediate checkpoint" : load_filename); + printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); + printf0("| vocab_size V | %-50d |\n", model.config.vocab_size); + printf0("| padded_vocab_size Vp | %-50d |\n", model.config.padded_vocab_size); + printf0("| num_layers L | %-50d |\n", model.config.num_layers); + printf0("| num_heads NH | %-50d |\n", model.config.num_heads); + printf0("| channels C | %-50d |\n", model.config.channels); + printf0("| num_parameters | %-50zu |\n", model.num_parameters); + printf0("+-----------------------+----------------------------------------------------+\n"); + + // build DataLoaders for both train and val + int permute_train_loader = (overfit_single_batch == 1) ? 0 : 1; + DataLoader train_loader, val_loader; + dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, permute_train_loader); + dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 0); + // figure out the number of training steps we will run for + int train_num_batches = max_steps; // passed in from command line + if (train_num_batches == -1) { + // sensible default is to train for exactly one epoch + size_t ntok = train_loader.num_tokens; + // the number of (outer loop) steps each process should take for us to reach one epoch + train_num_batches = ntok / total_batch_size; + } + // figure out the number of validation steps to run for + int val_num_batches = val_max_steps; // passed in from command line + if (val_num_batches == -1) { + // sensible default is to evaluate the full validation split + size_t ntok = val_loader.num_tokens; + // note that unlike the training loop, there is no gradient accumulation inner loop here + val_num_batches = ntok / tokens_per_fwdbwd; + } + printf0("| train_num_batches | %-50d |\n", train_num_batches); + printf0("| val_num_batches | %-50d |\n", val_num_batches); + printf0("+-----------------------+----------------------------------------------------+\n"); + + // build an EvalLoader for HellaSwag + EvalLoader eval_loader; + const char* hellaswag_path = "dev/data/hellaswag/hellaswag_val.bin"; + const bool hellaswag_available = access(hellaswag_path, F_OK) == 0; + const bool run_hellaswag = hellaswag_eval && hellaswag_available; + if (run_hellaswag) { + evalloader_init(&eval_loader, hellaswag_path, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes); + } + printf0("| run hellaswag | %-50s |\n", run_hellaswag ? "yes" : "no"); + printf0("+-----------------------+----------------------------------------------------+\n"); + + // pretty print in a table the multi-gpu configuration as well + set_zero_configs(&multi_gpu_config, zero_stage, model.num_parameters); + printf0("| num_processes | %-50d |\n", multi_gpu_config.num_processes); + printf0("| zero_stage | %-50d |\n", multi_gpu_config.zero_stage); + printf0("+-----------------------+----------------------------------------------------+\n"); + + // prints outside of pretty table to here and below + if (!hellaswag_available) { + printf0("HellaSwag eval not found at %s, skipping its evaluation\n", hellaswag_path); + printf0("You can run `python dev/data/hellaswag.py` to export and use it with `-h 1`.\n"); + } + // more prints related to allocations from gpt2_build_from_checkpoint down here to not mess up our table above + printf0("num_parameters: %zu => bytes: %zu\n", model.num_parameters, model.num_parameters_bytes); + printf0("allocated %d MiB for model parameters\n", (int)round(model.num_parameters_bytes / (1024 * 1024))); + // few more prints for gradient accumulation math up above + printf0("batch_size B=%d * seq_len T=%d * num_processes=%d and total_batch_size=%d\n", + B, T, multi_gpu_config.num_processes, total_batch_size); + printf0("=> setting grad_accum_steps=%d\n", grad_accum_steps); + + // set up logging + if (multi_gpu_config.process_rank == 0) { create_dir_if_not_exists(output_log_dir); } + Logger logger; + logger_init(&logger, output_log_dir, multi_gpu_config.process_rank, resume); + + // set up the Tokenizer + Tokenizer tokenizer; + tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); + + // set up learning rate scheduler + LearningRateScheduler lr_scheduler; + lr_scheduler_init(&lr_scheduler, lr_scheduler_type, learning_rate, + warmup_iterations, train_num_batches, final_learning_rate_frac); + + // some memory for generating samples from the model + int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); + floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX)); + float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float)); + + // if we found a checkpoint to resume from, load the optimization state + int step = 0; + gpt2_allocate_state(&model, B, T); + if (resuming == 1) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); + load_state(&step, &model, &train_loader, filename_buffer); + } + + // init an OutlierDetector the training loss + OutlierDetector loss_outlier_detector, grad_norm_outlier_detector; + init_detector(&loss_outlier_detector); + init_detector(&grad_norm_outlier_detector); + + // do some checks here before we kick off training + // cross-check the desired sequence length T with the model's max sequence length + if (T < model.config.max_seq_len) { + printf0("!!!!!!!!\n"); + printf0("WARNING:\n"); + printf0("- The training sequence length is: T=%d (set with -t)\n", T); + printf0("- The model's max sequence length is: max_seq_len=%d\n", model.config.max_seq_len); + printf0("You are attempting to train with a sequence length shorter than the model's max.\n"); + printf0("This will lead to unused parameters in the wpe position embedding weights.\n"); + printf0("If you know what you're doing you can ignore this warning.\n"); + printf0("If you're like ???, you are most likely misconfiguring your training run.\n"); + printf0("---> HINT: If you're training GPT-2 use -t 1024. If GPT-3, use -t 2048.\n"); + printf0("!!!!!!!!\n"); + } + // in any case, this must be true or we'd index beyond the model's wpe (position embedding table) + assert(T <= model.config.max_seq_len); + + // train + cudaEvent_t start, end; + cudaCheck(cudaEventCreate(&start)); + cudaCheck(cudaEventCreate(&end)); + cudaCheck(cudaProfilerStart()); + double total_sum_iteration_time_s = 0.0; + float ema_tokens_per_second = 0.0f; + for (; step <= train_num_batches; step++) { + NvtxRange step_range("Train step", step); + + int last_step = step == train_num_batches; + + // once in a while estimate the validation loss (all processes collaborate) + if (step % val_loss_every == 0 || last_step) { + NvtxRange validation_range("validation"); + float val_loss = 0.0f; + dataloader_reset(&val_loader); + for (int i = 0; i < val_num_batches; i++) { + dataloader_next_batch(&val_loader); + val_loss += gpt2_validate(&model, val_loader.inputs, val_loader.targets, B, T); + } + val_loss /= val_num_batches; + val_loss = multi_gpu_cpu_float_sum(val_loss, &multi_gpu_config) / multi_gpu_config.num_processes; + printf0("val loss %f\n", val_loss); + logger_log_val(&logger, step, val_loss); + } + + // once in a while estimate HellaSwag accuracy (all processes collaborate) + if (run_hellaswag && + ((step > 0 && step % val_loss_every == 0) || last_step)) { + NvtxRange evaluation_range("evaluation"); + float eval_acc_norm = 0.0f; + evalloader_reset(&eval_loader); + for (int i = 0; i < eval_loader.num_batches; i++) { + if (i % 10 == 0) { printf("evaluating HellaSwag: %d/%d\r", i, eval_loader.num_batches); } + evalloader_next_batch(&eval_loader); + gpt2_validate(&model, eval_loader.inputs, eval_loader.targets, B, T); + int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses); + eval_acc_norm += (float)correct; + } + // careful because not all ranks may have the exact same allocation of number of examples + eval_acc_norm = multi_gpu_cpu_float_sum(eval_acc_norm, &multi_gpu_config); + printf0("HellaSwag: %d/%d = %f\n", (int)eval_acc_norm, eval_loader.num_examples, eval_acc_norm / eval_loader.num_examples); + logger_log_eval(&logger, step, eval_acc_norm / eval_loader.num_examples); + } + + // once in a while do model inference to print generated text (only rank 0) + if (multi_gpu_config.process_rank == 0 && sample_every > 0 && + (step > 0 && (step % sample_every) == 0 || last_step)) { + NvtxRange generation_range("generation"); + unsigned long long sample_rng_state = 1337; + // fill up gen_tokens with the <|endoftext|> token, which kicks off the generation + int eot_token = tokenizer.eot_token; + for(int i = 0; i < B * T; ++i) { + gen_tokens[i] = eot_token; + } + // now sample from the model autoregressively + printf("generating:\n---\n"); + for (int t = 1; t < genT; t++) { + NvtxRange generation_range("Generation step", t); + // we try not to be too wasteful for inference by not calculating all of B,T + // Using a smaller B is always bit-for-bit identical, but T is more tricky + // for non-CUDNN, we need to make sure the attention buffer is memset to 0 + // for cuDNN, it might suddenly decide to use a slightly different algorithm... + // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical + // (but even if it wasn't fully identical that's probably not the end of the world) + // note this is still somewhat wasteful because we don't have a KV cache! + gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); + // get the V-dimensional vector probs[0, t-1, :] + floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; + // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) + cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); + // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) + for (int i = 0; i < model.config.vocab_size; i++) { + cpu_logits[i] = (float)cpu_logits_raw[i]; + } + // sample the next token + float coin = random_f32(&sample_rng_state); + int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin); + gen_tokens[t] = next_token; + // print the generated token, either using the Tokenizer or a fallback + if (tokenizer.init_ok) { + const char* token_str = tokenizer_decode(&tokenizer, next_token); + safe_printf(token_str); + } else { + // fall back to printing the token id + printf("%d ", next_token); + } + fflush(stdout); + } + printf("\n---\n"); + } + + // once in a while checkpoint the optimization state (all ranks) + if ((checkpoint_every > 0 && output_log_dir != NULL && resuming == 0) && + ((step > 0 && step % checkpoint_every == 0) || last_step)) { + // writes model .bin file, state .bin files, and DONE file for step + write_checkpoint(output_log_dir, step, &model, &train_loader, &multi_gpu_config); + // we only keep checkpoints_keep checkpoints on disk to save space + // so now that we wrote a new checkpoint, delete one old one (unless it is a "major" checkpoint) + // we only do this is checkpoint keeping is turned on (checkpoints_keep > 0) + int step_delete = step - checkpoints_keep * checkpoint_every; + if (checkpoints_keep > 0 && step_delete > 0 && + (major_checkpoint_every == 0 || step_delete % major_checkpoint_every != 0) + ) { + delete_checkpoint(output_log_dir, step_delete, &multi_gpu_config); + } + } + resuming = 0; + + // bit confusing: we want to make sure to eval and sample on 0th iteration + // but also after the very last iteration. so we loop for step <= train_num_batches + // instead of just < train_num_batches (one extra due to <=), only to do + // the validation/sampling one last time, and then we break right here as we're done. + if (last_step) { break; } + + // --------------- TRAINING SECTION BEGIN ----------------- + if (overfit_single_batch == 1) { + // if we are trying to overfit a single batch, we reset the loader here + dataloader_reset(&train_loader); + } + // do one training step, doing forward/backward/update on total_batch_size tokens + cudaCheck(cudaEventRecord(start)); + // gradient and loss accumulation loop over micro-batches + for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) { + // fetch the next data batch + dataloader_next_batch(&train_loader); + // forward pass. note that we pass in grad_accum_steps, which scales down the loss + gpt2_forward(&model, train_loader.inputs, B, T); + // backward pass. all model params accumulate gradients with += inside this inner loop + gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step); + } + float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss)); // loss z-score + // fetch the next learning rate + float step_learning_rate = get_learning_rate(&lr_scheduler, step); + // calculate the gradient norm and how much we wish to scale the gradient + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score + // update the model parameters + if (isfinite(zloss) && skip_update_lossz != 0.0f && zloss > skip_update_lossz) { + printf0("skipping update due to loss z-score of %f\n", zloss); + } else if (isfinite(zgrad) && skip_update_gradz != 0.0f && zgrad > skip_update_gradz) { + printf0("skipping update due to grad z-score of %f\n", zgrad); + } else { + // clip the gradient norm to a maximum value + float grad_clip = 1.0f; + float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f; + gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); + } + cudaCheck(cudaEventRecord(end)); + cudaCheck(cudaEventSynchronize(end)); // wait for the end event to finish to get correct timings + // --------------- TRAINING SECTION END ------------------- + // everything that follows now is just diagnostics, prints, logging, etc. + + // todo - move or double-buffer all of this timing logic to avoid idling the GPU at this point! + float time_elapsed_ms; + cudaCheck(cudaEventElapsedTime(&time_elapsed_ms, start, end)); + size_t tokens_processed = (size_t)multi_gpu_config.num_processes * B * T * grad_accum_steps; + float tokens_per_second = tokens_processed / time_elapsed_ms * 1000.0f; + float bias_corrected_ema_tokens_per_second = tokens_per_second; // by default set to non-ema version + if (step > 0) { // consider the first batch to be a warmup (e.g. cuBLAS/cuDNN initialisation) + total_sum_iteration_time_s += time_elapsed_ms / 1000.0f; + // smooth out the tok/s with an exponential moving average, and bias correct just like in AdamW + ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second; + bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step)); + } + float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f); + printf0("step %4d/%d | loss %7.6f (%+.2fz)| norm %6.4f (%+.2fz)| lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", + step + 1, train_num_batches, model.mean_loss, zloss, grad_norm, zgrad, step_learning_rate, + time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second); + if(log_gpu_every > 0 && (step + 1) % log_gpu_every == 0) { + GPUUtilInfo gpu_info = get_gpu_utilization_info(); + printf0(" compute %2.1f%% | memory: %2.1f%% | fan: %2d%% | %4d MHz / %4d MHz | %3d W / %3d W | %d°C / %d°C | %s\n", + gpu_info.gpu_utilization, gpu_info.mem_utilization, gpu_info.fan, gpu_info.clock, gpu_info.max_clock, gpu_info.power / 1000, gpu_info.power_limit / 1000, + gpu_info.temperature, gpu_info.temp_slowdown, gpu_info.throttle_reason); + } + logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm); + + // disable the profiler after 3 steps of optimization + if (step == 3) { cudaProfilerStop(); } + } + // add a total average, for optimizations that are only mild improvements (excluding 1st batch as warmup) + printf0("total average iteration time: %f ms\n", total_sum_iteration_time_s / (train_num_batches-1) * 1000); + + // free and destroy everything + cudaCheck(cudaEventDestroy(end)); + cudaCheck(cudaEventDestroy(start)); + if (run_hellaswag) { evalloader_free(&eval_loader); } + dataloader_free(&train_loader); + dataloader_free(&val_loader); + tokenizer_free(&tokenizer); + free(cpu_logits_raw); + free(cpu_logits); + free(gen_tokens); + multi_gpu_config_free(&multi_gpu_config); + gpt2_free(&model); + common_free(model); + return 0; +} +#endif From 01bc4c685a78c52b6aa5c5c12b5482d26a9d1e51 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 13 Sep 2024 20:44:10 +0000 Subject: [PATCH 02/63] first set of changes to match up the .py and the .cu version. default hyperparameters, introduce int+float section of header, read the header and EXIT for now --- Makefile | 3 + train_llama3.cu | 207 +++++++++++++++--------------------------------- train_llama3.py | 39 ++++----- 3 files changed, 86 insertions(+), 163 deletions(-) diff --git a/Makefile b/Makefile index 73b83720c..c1f3f2f9d 100644 --- a/Makefile +++ b/Makefile @@ -285,6 +285,9 @@ test_gpt2fp32cu: test_gpt2_fp32.cu profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN) $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) +train_llama3cu: train_llama3.cu $(NVCC_CUDNN) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) + clean: $(REMOVE_FILES) $(TARGETS) $(REMOVE_BUILD_OBJECT_FILES) diff --git a/train_llama3.cu b/train_llama3.cu index 70d8d0c5a..9e0e995ea 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -89,8 +89,14 @@ typedef struct { int vocab_size; // vocab size, e.g. 50257 int padded_vocab_size; // padded to e.g. %128==0, 50304 int num_layers; // number of layers, e.g. 12 - int num_heads; // number of heads in attention, e.g. 12 + int num_heads; // number of query heads in attention, e.g. 12 + int num_kv_heads; // number of key and value heads in attention, e.g. 4 (<-- new in Llama 3) int channels; // number of channels, e.g. 768 + int multiple_of; // used in feedforward layer sizing, e.g. 1024 (<-- new in Llama 3) + int use_scaled_rope; // whether to use scaled rope + float ffn_dim_multiplier; // multiplier used in feedforward layer, e.g. 1.3 (<-- new in Llama 3) + float norm_eps; // epsilon used in layernorm, e.g. 1e-5 + float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3) } GPT2Config; // the parameters of the model @@ -467,10 +473,14 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // read in model from a checkpoint file FILE *model_file = fopenCheck(checkpoint_path, "rb"); - int model_header[256]; - freadCheck(model_header, sizeof(int), 256, model_file); - if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(EXIT_FAILURE); } - int version = model_header[1]; + int header_int[256]; // int section of the header + freadCheck(header_int, sizeof(int), 256, model_file); + assert(sizeof(int) == 4); // i think the python export code currently assumes this is int32 + float header_float[256]; // float section of the header + freadCheck(header_float, sizeof(float), 256, model_file); + assert(sizeof(float) == 4); // i think the python export code currently assumes this is float32 + if (header_int[0] != 20240803) { printf("Bad magic model file\n"); exit(EXIT_FAILURE); } + int version = header_int[1]; if (!(version == 3 || version == 5)) { // 3 = fp32, padded vocab // 5 = bf16, padded vocab, layernorms also in bf16 @@ -494,13 +504,44 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w } } - // read in hyperparameters - model->config.max_seq_len = model_header[2]; - model->config.vocab_size = model_header[3]; - model->config.num_layers = model_header[4]; - model->config.num_heads = model_header[5]; - model->config.channels = model_header[6]; - model->config.padded_vocab_size = model_header[7]; + // read in hyperparameters from the header + // first the integer section + model->config.max_seq_len = header_int[2]; + model->config.vocab_size = header_int[3]; + model->config.padded_vocab_size = model->config.vocab_size; // in Llama 3 there is no need for padding + model->config.num_layers = header_int[4]; + model->config.num_heads = header_int[5]; + model->config.num_kv_heads = header_int[6]; + model->config.channels = header_int[7]; + model->config.multiple_of = header_int[8]; + model->config.use_scaled_rope = header_int[9]; + int major_version = header_int[10]; // currently unused, e.g. 3 + int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1) + // now the float section + model->config.ffn_dim_multiplier = header_float[0]; + model->config.norm_eps = header_float[1]; + model->config.rope_theta = header_float[2]; + + // ------------------------------------------------------------------------ + // TODO TAKE OUT ---------------------------------------------------------- + // Debugging: print all of the values above to check visually and EXIT + printf("CHECK:\n"); + printf("max_seq_len: %d\n", model->config.max_seq_len); + printf("vocab_size: %d\n", model->config.vocab_size); + printf("padded_vocab_size: %d\n", model->config.padded_vocab_size); + printf("num_layers: %d\n", model->config.num_layers); + printf("num_heads: %d\n", model->config.num_heads); + printf("num_kv_heads: %d\n", model->config.num_kv_heads); + printf("channels: %d\n", model->config.channels); + printf("multiple_of: %d\n", model->config.multiple_of); + printf("use_scaled_rope: %d\n", model->config.use_scaled_rope); + printf("major version: %d\n", major_version); + printf("minor version: %d\n", minor_version); + printf("ffn_dim_multiplier: %f\n", model->config.ffn_dim_multiplier); + printf("norm_eps: %f\n", model->config.norm_eps); + printf("rope_theta: %f\n", model->config.rope_theta); + exit(0); + // ------------------------------------------------------------------------ // allocate memory for the model parameters gpt2_allocate_weights(model); @@ -516,131 +557,6 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w cudaCheck(cudaDeviceSynchronize()); } -void gpt2_set_hyperparameters(GPT2Config* config, const char* depth_str) { - int depth = atoi(depth_str); - assert(depth > 0); // atoi returns 0 if not a number - int channels, num_heads; - if (depth == 6) { channels = 384; num_heads = 6; } // (unofficial) gpt2-tiny (30M) - else if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M) - else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M) - else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M) - else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M) - else if (depth == 60) { channels = 1920; num_heads = 30; } // (unofficial) 2.7B - else if (depth == 72) { channels = 2880; num_heads = 30; } // (unofficial) 7.3B - else if (depth == 84) { channels = 3456; num_heads = 36; } // (unofficial) 12.2B - else { fprintf(stderr, "Unsupported GPT-2 depth: %d\n", depth); exit(EXIT_FAILURE); } - config->num_layers = depth; - config->channels = channels; - config->num_heads = num_heads; - config->max_seq_len = 1024; -} - -void gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) { - // we use channels instead of depth for GPT-3 because GPT-3 model depths are not one-to-one - // note that our models are not necessarily identical to GPT-3 because - // we use dense attention, not the alternating dense/banded attention of GPT-3 - int channels = atoi(channels_str); - assert(channels > 0); // atoi returns 0 if not a number - int depth, head_size; - if (channels == 384) { depth = 6; head_size = 64; } // (unofficial) gpt3-tiny (31M) - else if (channels == 768) { depth = 12; head_size = 64; } // gpt3-small (125M) - else if (channels == 1024) { depth = 24; head_size = 64; } // gpt3-medium (350M) - else if (channels == 1536) { depth = 24; head_size = 96; } // gpt3-large (760M) - else if (channels == 2048) { depth = 24; head_size = 128; } // gpt3-xl (1.3B) [heads fixed] - else if (channels == 2560) { depth = 32; head_size = 80; } // gpt3-2.7B - else if (channels == 4096) { depth = 32; head_size = 128; } // gpt3-6.7B - else if (channels == 5140) { depth = 40; head_size = 128; } // gpt3-13B - else if (channels == 12288) { depth = 96; head_size = 128; } // gpt3 (175B) - else { fprintf(stderr, "Unsupported GPT-3 channels: %d\n", channels); exit(EXIT_FAILURE); } - assert(channels % head_size == 0); - config->num_layers = depth; - config->channels = channels; - config->num_heads = channels / head_size; - config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2 -} - -void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { - // The model descriptor can be: - // - legacy format "dX", where X is number, e.g. "d12". This creates GPT-2 model with 12 layers. - // - new explicit format "gpt2:dX", same as above, e.g. "gpt2:d48" for GPT-2 with 48 layers. - // - "gpt3:cX", where X is now the channel count, e.g. "gpt3:c768" is the smallest GPT-3 model. - - // check the valid prexies and dispatch to the right setup function - assert(descriptor != NULL); - size_t len = strlen(descriptor); - if (len > 1 && descriptor[0] == 'd') { - gpt2_set_hyperparameters(&model->config, descriptor + 1); // pass along the depth str without the 'd' - } else if (len > 6 && strncmp(descriptor, "gpt2:d", 6) == 0) { - gpt2_set_hyperparameters(&model->config, descriptor + 6); // pass along the depth str without the 'gpt2:d' - } else if (len > 6 && strncmp(descriptor, "gpt3:c", 6) == 0) { - gpt3_set_hyperparameters(&model->config, descriptor + 6); // pass along the channels str without the 'gpt3:c' - } else { - fprintf(stderr, "Unsupported model descriptor: %s\n", descriptor); exit(EXIT_FAILURE); - } - - // both GPT-2 and GPT-3 use the same tokenizer with 50257 tokens - model->config.vocab_size = 50257; - model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel efficiency - - gpt2_allocate_weights(model); - - // allocate and random init the memory for all the parameters with GPT-2 schema - // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5) - // NOTE: assuming all parameters are of the type floatX, could be relaxed later - mt19937_state init_rng; - manual_seed(&init_rng, 42); - floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes); - memset(params_memory_cpu, 0, model->num_parameters_bytes); - // fill in all the weights with random values - float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); - // we have to init all these tensors exactly in the order that PyTorch initializes them - // so that we can match them up and get correctness and exactly the same initial conditions - size_t L = model->config.num_layers; - size_t offset = 0; - for (int l = 0; l < L; l++) { - offset = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - // the layernorm parameters are all initialized to 1 - if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once - for (size_t j = 0; j < model->param_elements[i]; j++) { - params_memory_cpu[offset + j] = 1.0f; - } - } - // weights tensors are handled here - if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors - || i == 4 || i == 6 || i == 10 || i == 12) { - size_t n = model->param_elements[i]; - size_t layer_offset = 0; - if (i == 0) { - // for wte tensor (padded vocab) override to init V instead of Vp rows - n = model->config.vocab_size * model->config.channels; - } - if (i == 4 || i == 6 || i == 10 || i == 12) { - // weight tensors, we are only initializing layer l - assert(n % L == 0); - n = n / L; - layer_offset = l * n; - } - // in GPT-2, the projections back into the residual stream are additionally - // scaled by 1/sqrt(2*L) for training stability - float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f; - // okay let's draw the random numbers and write them - float *fp32_buffer = (float*)mallocCheck(n * sizeof(float)); - normal_(fp32_buffer, n, 0.0f, scale, &init_rng); - for (size_t j = 0; j < n; j++) { - params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j]; - } - free(fp32_buffer); - } - offset += model->param_elements[i]; - } - } - - // copy them to GPU - cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); - free(params_memory_cpu); -} - // propagate inputs through the network to produce logits. // right now, this function is fully synchronous with the host void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { @@ -1420,7 +1336,7 @@ int main(int argc, char *argv[]) { // read in the (optional) command line arguments const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin"; - const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model + const char* load_filename = "llama3.1_8B_bf16.bin"; // bf16 weights of the Llama 3.1 8B model const char* lr_scheduler_type = "cosine"; const char* output_log_dir = NULL; int checkpoint_every = 0; // write checkpoints every how many steps? @@ -1428,9 +1344,9 @@ int main(int argc, char *argv[]) { int major_checkpoint_every = 0; // major checkpoints never get deleted when maintaining history int resume = 0; // resume the optimization, if one is found inside output_log_dir? int B = 4; // batch size - int T = 1024; // sequence length max + int T = 64; // sequence length max int total_batch_size = -1; // will be calculated down below later, if not provided - float learning_rate = 3e-4f; + float learning_rate = 1e-5f; int log_gpu_every = -1; int warmup_iterations = 0; float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training @@ -1441,8 +1357,8 @@ int main(int argc, char *argv[]) { int val_max_steps = 20; // how many batches max do we eval for validation loss? int sample_every = 20; // every how many steps to do inference? int genT = 64; // number of steps of inference we will do - int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once - int max_steps = -1; + int overfit_single_batch = 1; // useful for debugging, 1 = only load a single data batch once + int max_steps = 10; int override_enable_tf32 = 1; int use_master_weights = 1; int gelu_fusion = -1; // 0 = none, 1 = forward, 2 = forward+backward (-1 => per-GPU default) @@ -1580,9 +1496,10 @@ int main(int argc, char *argv[]) { // otherwise, if this is a .bin file, we assume it's a model, let's init from it gpt2_build_from_checkpoint(&model, load_filename); } else { - // if it's not .bin, it could be a "special descriptor". This descriptor is used to - // construct GPT-2 / GPT-3 models in a convenient format. See the function for docs. - gpt_build_from_descriptor(&model, load_filename); + // For Llama 3.1 we currently demand a .bin file to load the model from, and + // initializing from scratch is currently not supported (but can be added later) + printf0("Error: Llama 3 cannot be initialized from scratch right now\n"); + exit(EXIT_FAILURE); } model.use_master_weights = use_master_weights; diff --git a/train_llama3.py b/train_llama3.py index f9daafde0..8b59d63dd 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -875,28 +875,31 @@ def write_model(model, filename, dtype): "float32": 3, # 3: all tensors are fp32 "bfloat16": 5, # 5: all tensors are bf16 }[dtype] - header = torch.zeros(256, dtype=torch.int32) - header[0] = 20240803 # magic - header[1] = version # checkpoint version - header[2] = model.config.block_size - header[3] = model.config.vocab_size - header[4] = model.config.n_layer - header[5] = model.config.n_head - header[6] = model.config.n_kv_head - header[7] = model.config.n_embd - header[8] = model.config.ffn_dim_multiplier - header[9] = model.config.multiple_of - header[10] = model.config.norm_eps - header[11] = model.config.rope_theta - header[12] = model.config.use_scaled_rope - header[13] = model.config.max_gen_batch_size - header[14] = int(model.config.version.split('.')[0]) # major version - header[15] = int(model.config.version.split('.')[1]) # minor version + # integer section of the header + header_int = torch.zeros(256, dtype=torch.int32) + header_int[0] = 20240803 # magic + header_int[1] = version # checkpoint version + header_int[2] = model.config.block_size + header_int[3] = model.config.vocab_size + header_int[4] = model.config.n_layer + header_int[5] = model.config.n_head + header_int[6] = model.config.n_kv_head + header_int[7] = model.config.n_embd + header_int[8] = model.config.multiple_of + header_int[9] = int(model.config.use_scaled_rope) + header_int[10] = int(model.config.version.split('.')[0]) # major version + header_int[11] = int(model.config.version.split('.')[1]) # minor version + # float section of the header + header_float = torch.zeros(256, dtype=torch.float32) + header_float[0] = model.config.ffn_dim_multiplier + header_float[1] = model.config.norm_eps + header_float[2] = model.config.rope_theta # 2) the parameters follow the header params = {name: param.cpu() for name, param in model.named_parameters()} # now write to file with open(filename, "wb") as file: - file.write(header.numpy().tobytes()) # header + file.write(header_int.numpy().tobytes()) # int header + file.write(header_float.numpy().tobytes()) # float header write_tensors(params, model.config.n_layer, file, dtype) # params print(f"wrote {filename}") From b883560d264a173f6853d8c4097d512be3d51165 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 13 Sep 2024 22:47:30 +0000 Subject: [PATCH 03/63] change the export code of Llama 3 to be very GPT-2 friendly, using a combination of 3 hacks. this will make it so that we have to change very little code on the C side --- train_llama3.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/train_llama3.py b/train_llama3.py index 8b59d63dd..7ed0cdf7b 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -847,25 +847,61 @@ def write_bf16(tensor, file): def write_tensors(model_tensors, L, file, dtype): # writes LLaMA 3 model's weights to a binary file + # things get a bit more complicated though: + # 1) We want to maintain the ability to finetune just the biases in the C code + # and also GPT-2 supported biases and we want to touch as little code as possible. + # => We will generate biases of all zeros and write them here. It's very little data. + # 2) We want to exactly preserve the GPT-2 code paths, so we can't have SwiGLU using two + # separate nn.Linear layers c_fc and c_fc2. We will merge them into a single c_fc layer. + # Then later in the C code, we do pointer arithmetic to recover them fully internal to + # the SwiGLU layer + # 3) Llama 3 does not use position embeddings table so we have to remove it. AT THE SAME TIME, + # and, very conveniently, Llama 3 does not share the output projection weights with the + # token embeddings table, so we have to add it. Well instead of removing and adding, we + # are going to write the output projection weights into the slot previously used for the + # position embeddings table. Everyone is happy, very little code is changed from GPT-2. assert dtype in {"float32", "bfloat16"} write_fun = write_fp32 if dtype == "float32" else write_bf16 write_fun(model_tensors["transformer.wte.weight"], file) # (V, C) + write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here! for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) + for i in range(L): # (L, C) + # see hack (1) above for these + # yes i know this is inefficient and dumb i'm just matching the train_gpt2.py code format + write_fun(torch.zeros_like(model_tensors[f"transformer.h.{i}.ln_1.weight"]), file) for i in range(L): # (L, 3C, C) write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file) + for i in range(L): # (L, 3C) + w = model_tensors[f"transformer.h.{i}.attn.c_attn.weight"] + write_fun(torch.zeros(w.size(0), dtype=w.dtype), file) for i in range(L): # (L, C, C) write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file) + for i in range(L): # (L, C) + w = model_tensors[f"transformer.h.{i}.attn.c_proj.weight"] + write_fun(torch.zeros(w.size(0), dtype=w.dtype), file) for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) + for i in range(L): # (L, C) + write_fun(torch.zeros_like(model_tensors[f"transformer.h.{i}.ln_2.weight"]), file) + # now for hack (2) here... inline model surgery to concat c_fc and c_fc2 + # ------------------------------------------- for i in range(L): # (L, 4C, C) + # simply write the two weights in sequence write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file) - for i in range(L): # (L, 4C, C) write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"], file) + for i in range(L): # (L, 4C) + w1 = model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"] + w2 = model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"] + write_fun(torch.zeros(w1.size(0) + w2.size(0), dtype=w1.dtype), file) + # ------------------------------------------- for i in range(L): # (L, C, 4C) write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file) + for i in range(L): # (L, C) + w = model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"] + write_fun(torch.zeros(w.size(0), dtype=w.dtype), file) write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, ) - write_fun(model_tensors["lm_head.weight"], file) # (V, C) + write_fun(torch.zeros_like(model_tensors["transformer.ln_f.weight"]), file) # (C, ) def write_model(model, filename, dtype): # everything we need to instantiate the model From 88663086fb23fe0bef5b773ec3c2ba8ac9031bf8 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Sep 2024 17:45:28 +0000 Subject: [PATCH 04/63] adapt the sizes of all the parameter tensors and load them from file. so now we are loading all the Llama 3 weights. I verified that the sizes of all the tensors agree with python, and the total number of parameters --- train_llama3.cu | 48 ++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index 9e0e995ea..4acb7e6b8 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -122,27 +122,43 @@ typedef struct { static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { + // see train_llama3.py write_tensors() function for detailed docs of some of the trickery here + // trick 1: all biases are still present but set to zero + // trick 2: the SwiGLU weights are "packed" into one, concatenated + // trick 3: the positional embedding is replaced with the final classifier layer weights size_t Vp = config.padded_vocab_size; size_t C = config.channels; - size_t maxT = config.max_seq_len; size_t L = config.num_layers; + // calculation following the .py code inside CausalSelfAttention + // we have to calculate the number of channels in the QKV projection + size_t n_head = config.num_heads; + size_t n_kn_head = config.num_kv_heads; + size_t hd = C / n_head; // head dimension + size_t qkv_channels = (n_head + 2*n_kn_head) * hd; // Q, K, V channels + // calculation following the .py code inside MLP + // we have to calculate the number of channels in the SwiGLU projections c_fc + c_fc2 + size_t hidden_dim = 4 * C; + hidden_dim = (2 * hidden_dim) / 3; + hidden_dim = config.ffn_dim_multiplier * hidden_dim; + hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) / config.multiple_of); + size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated + // now populate the parameter sizes param_sizes[0] = Vp * C; // wte - param_sizes[1] = maxT * C; // wpe + param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights) param_sizes[2] = L * C; // ln1w - param_sizes[3] = L * C; // ln1b - param_sizes[4] = L * (3 * C) * C; // qkvw - param_sizes[5] = L * (3 * C); // qkvb + param_sizes[3] = L * C; // ln1b; (1) all biases are zero it's ok + param_sizes[4] = L * (qkv_channels) * C; // qkvw + param_sizes[5] = L * (qkv_channels); // qkvb param_sizes[6] = L * C * C; // attprojw param_sizes[7] = L * C; // attprojb param_sizes[8] = L * C; // ln2w param_sizes[9] = L * C; // ln2b - param_sizes[10] = L * (4 * C) * C; // fcw - param_sizes[11] = L * (4 * C); // fcb - param_sizes[12] = L * C * (4 * C); // fcprojw + param_sizes[10] = L * ffn_channels * C; // fcw; (2) this is twice the size + param_sizes[11] = L * ffn_channels; // fcb + param_sizes[12] = L * C * hidden_dim; // fcprojw param_sizes[13] = L * C; // fcprojb param_sizes[14] = C; // lnfw param_sizes[15] = C; // lnfb - // populate the parameter sizes in bytes (all the same for now, keeping for future use) for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { param_sizeof[i] = sizeof(floatX); @@ -365,6 +381,16 @@ void gpt2_allocate_weights(GPT2 *model) { model->num_parameters += model->param_elements[i]; model->num_parameters_bytes += model->param_elements[i] * model->param_sizeof[i]; } + + // TODO TAKE OUT ---------------------------------------------------------- + // DEBUGGING: print out the sizes of the parameters + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + printf("param_elements[%d] = %zu\n", i, model->param_elements[i]); + } + printf("num_parameters = %zu\n", model->num_parameters); + printf("num_parameters_bytes = %zu\n", model->num_parameters_bytes); + // ------------------------------------------------------------------------ + // create memory for model parameters on the device assert(model->params_memory == nullptr); model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); @@ -540,7 +566,6 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w printf("ffn_dim_multiplier: %f\n", model->config.ffn_dim_multiplier); printf("norm_eps: %f\n", model->config.norm_eps); printf("rope_theta: %f\n", model->config.rope_theta); - exit(0); // ------------------------------------------------------------------------ // allocate memory for the model parameters @@ -1515,6 +1540,9 @@ int main(int argc, char *argv[]) { printf0("| num_parameters | %-50zu |\n", model.num_parameters); printf0("+-----------------------+----------------------------------------------------+\n"); + // DEBUGGING: we only work until this point right now, so exit here + exit(0); + // build DataLoaders for both train and val int permute_train_loader = (overfit_single_batch == 1) ? 0 : 1; DataLoader train_loader, val_loader; From 45026f6eadfb919a5212b015c1dc56b6c7c9f0d7 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Sep 2024 19:43:44 +0000 Subject: [PATCH 05/63] make llama3cu phony --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index c1f3f2f9d..dba457ce8 100644 --- a/Makefile +++ b/Makefile @@ -244,7 +244,7 @@ else endif # PHONY means these targets will always be executed -.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu +.PHONY: all train_llama3cu train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu # Add targets TARGETS = train_gpt2 test_gpt2 From 77e1d7afda1aaba823c10da65a5034644e5971b8 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Sep 2024 19:44:33 +0000 Subject: [PATCH 06/63] add support for dataloader to serve uint32_t tokens, as necessary in Llama 3 --- llmc/dataloader.h | 61 +++++++++++++++++++++++++++++++---------------- train_llama3.cu | 29 +++++++--------------- 2 files changed, 48 insertions(+), 42 deletions(-) diff --git a/llmc/dataloader.h b/llmc/dataloader.h index ad5829d06..7784bbabb 100644 --- a/llmc/dataloader.h +++ b/llmc/dataloader.h @@ -36,6 +36,7 @@ typedef struct { size_t T; size_t num_tokens; // total number of tokens size_t shard_num_samples; // total number of samples in the current shard per process + size_t token_dtype; // sizeof(uint16_t) (GPT-2) or sizeof(uint32_t) (Llama 3) // shards and current position glob_t glob_result; // stores the result of glob, for all shards we want to iterate size_t current_shard_idx; // the current shard we are reading from @@ -43,7 +44,7 @@ typedef struct { // file handle FILE* tokens_file; // data buffers - uint16_t* buffer; // we fread data from file into this buffer + void* buffer; // we fread data from file into this buffer int* inputs; // input tokens into transformer int* targets; // target tokens for the transformer // random shuffle related variables @@ -59,6 +60,7 @@ typedef struct { } DataLoader; int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { + // re-map the shard index via the generated permutation, if data shuffling if (loader->should_shuffle) { shard_index = loader->shard_indices[shard_index]; } @@ -72,13 +74,30 @@ int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { // validate the header int header[HEADER_SIZE]; freadCheck(header, sizeof(int), HEADER_SIZE, loader->tokens_file); - if (header[0] != 20240520) { + int magic = header[0]; + int gpt2_datafile_magic = 20240520; + int llama3_datafile_magic = 20240801; + if (!(magic == gpt2_datafile_magic || magic == llama3_datafile_magic)) { printf("Bad magic in the data file\n"); printf("---> HINT: Are you passing in a correct file?\n"); printf("---> HINT: The data encoding may have changed, re-run data prepro or refer again to README.\n"); exit(EXIT_FAILURE); } - if (header[1] != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); } + int version = header[1]; + if (magic == gpt2_datafile_magic && version != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); } + if (magic == llama3_datafile_magic && version != 7) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); } + // token dtype related logic + size_t token_dtype = (magic == gpt2_datafile_magic) ? sizeof(uint16_t) : sizeof(uint32_t); + if (loader->token_dtype == 0) { + // this is the first data shard; set the token dtype and some helper variables + loader->token_dtype = token_dtype; + loader->total_batch_size_bytes = ((loader->num_processes * (loader->B * loader->T)) * loader->token_dtype); + loader->local_batch_offset_bytes = loader->process_rank * loader->B * loader->T * loader->token_dtype; + } else { + // we expect consistency across shards + assert(loader->token_dtype == token_dtype); + } + // load the tokens int64_t ntok = header[2]; // number of tokens in the file assert(ntok > 0); // we expect some tokens in the file. this should never trip, right? // determine the file size and make sure it is consistent with the number of tokens @@ -86,13 +105,13 @@ int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { loader->file_size_bytes = ftell(loader->tokens_file); // read the offset, i.e. file size fseekCheck(loader->tokens_file, 0, SEEK_SET); // seek back to the beginning // we expect ntok in the file to be consistent with filesize, assert that is the case - int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t); + int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * loader->token_dtype; if (loader->file_size_bytes != expected_file_size) { printf("Error: file size is not as expected\n"); exit(EXIT_FAILURE); } - // -1 uint16_t due to us taking B*T+1 tokens but moving by B*T tokens - loader->shard_num_samples = (ntok * sizeof(uint16_t) - sizeof(uint16_t)) / loader->total_batch_size_bytes; + // -1 token due to us taking B*T+1 tokens but moving by B*T tokens + loader->shard_num_samples = (ntok * loader->token_dtype - loader->token_dtype) / loader->total_batch_size_bytes; return ntok; } @@ -153,8 +172,7 @@ void dataloader_init(DataLoader *loader, loader->tokens_file = NULL; loader->should_shuffle = should_shuffle; loader->header_bytes = HEADER_SIZE * sizeof(int); - loader->total_batch_size_bytes = ((loader->num_processes * (loader->B * loader->T)) * sizeof(uint16_t)); - loader->local_batch_offset_bytes = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); + loader->token_dtype = 0; // set to 0 to indicate that it is not yet set. it will be on first shard load // glob to get the list of files matching the pattern, these are our data shards int glob_status = glob(filename_pattern, 0, NULL, &loader->glob_result); @@ -181,17 +199,13 @@ void dataloader_init(DataLoader *loader, int64_t ntok_total = 0; for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) { int64_t shard_ntok = dataloader_load_shard_(loader, shard_index); - // we need at least one batch/shard, the way things are written right now. - // can be relaxed a lot later. + // we need at least one batch/shard, the way things are written right now, can be relaxed later assert(shard_ntok >= (int64_t) (num_processes * B * T + 1)); ntok_total += shard_ntok; } - // debugging prints - // printf("DataLoader: filename_pattern: %s\n", filename_pattern); - // printf("DataLoader: Found %ld tokens across %zu shards\n", ntok_total, loader->glob_result.gl_pathc); // allocate all the space we'll need - loader->buffer = (uint16_t*)mallocCheck((B * T + 1) * sizeof(uint16_t)); + loader->buffer = mallocCheck((B * T + 1) * loader->token_dtype); loader->inputs = (int*)mallocCheck(B * T * sizeof(int)); loader->targets = (int*)mallocCheck(B * T * sizeof(int)); loader->num_tokens = ntok_total; @@ -200,22 +214,28 @@ void dataloader_init(DataLoader *loader, dataloader_reset(loader); } +typedef int (*access_func_t)(const void*, size_t); +int access_uint16(const void* buffer, size_t i) { return (int)((uint16_t*)buffer)[i]; } +int access_uint32(const void* buffer, size_t i) { return (int)((uint32_t*)buffer)[i]; } + void dataloader_load_batch(DataLoader* loader) { assert(!loader->should_shuffle || (loader->should_shuffle && loader->intra_shard_indices != NULL)); assert(loader->current_sample_idx < loader->shard_num_samples); size_t idx = loader->should_shuffle ? loader->intra_shard_indices[loader->current_sample_idx] : loader->current_sample_idx; size_t global_batch_offset_bytes = idx * loader->total_batch_size_bytes; int64_t current_offset = loader->header_bytes + global_batch_offset_bytes + loader->local_batch_offset_bytes; - size_t B = loader->B; size_t T = loader->T; - // read B*T+1 uint16_t tokens from the file into buffer + // read B*T+1 tokens from the file into buffer fseekCheck(loader->tokens_file, (int) current_offset, SEEK_SET); - freadCheck(loader->buffer, sizeof(uint16_t), B*T+1, loader->tokens_file); + freadCheck(loader->buffer, loader->token_dtype, B*T+1, loader->tokens_file); + // depending on the dtype we have to access buffer differently + assert(loader->token_dtype == sizeof(uint16_t) || loader->token_dtype == sizeof(uint32_t)); + access_func_t access_func = (loader->token_dtype == sizeof(uint16_t)) ? access_uint16 : access_uint32; // decode the buffer into inputs and targets (cast to int) - for (int i = 0; i < B*T; i++) { - loader->inputs[i] = (int)loader->buffer[i]; - loader->targets[i] = (int)loader->buffer[i+1]; + for (size_t i = 0; i < B*T; i++) { + loader->inputs[i] = access_func(loader->buffer, i); + loader->targets[i] = access_func(loader->buffer, i + 1); } } @@ -228,7 +248,6 @@ void dataloader_next_batch(DataLoader *loader) { loader->current_sample_idx += 1; } - void dataloader_resume(DataLoader *loader, size_t current_shard_idx, size_t current_sample_idx) { // used during model resumption (-y 1) flag loader->current_shard_idx = current_shard_idx; diff --git a/train_llama3.cu b/train_llama3.cu index 4acb7e6b8..53e5fe728 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -1539,9 +1539,7 @@ int main(int argc, char *argv[]) { printf0("| channels C | %-50d |\n", model.config.channels); printf0("| num_parameters | %-50zu |\n", model.num_parameters); printf0("+-----------------------+----------------------------------------------------+\n"); - - // DEBUGGING: we only work until this point right now, so exit here - exit(0); + assert(T <= model.config.max_seq_len); // build DataLoaders for both train and val int permute_train_loader = (overfit_single_batch == 1) ? 0 : 1; @@ -1568,6 +1566,9 @@ int main(int argc, char *argv[]) { printf0("| val_num_batches | %-50d |\n", val_num_batches); printf0("+-----------------------+----------------------------------------------------+\n"); + // DEBUGGING: we only work until this point right now, so exit here + exit(0); + // build an EvalLoader for HellaSwag EvalLoader eval_loader; const char* hellaswag_path = "dev/data/hellaswag/hellaswag_val.bin"; @@ -1605,7 +1606,7 @@ int main(int argc, char *argv[]) { // set up the Tokenizer Tokenizer tokenizer; - tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); + // tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); // TODO: port tokenizer later from GPT2 -> Llama 3 // set up learning rate scheduler LearningRateScheduler lr_scheduler; @@ -1630,23 +1631,6 @@ int main(int argc, char *argv[]) { init_detector(&loss_outlier_detector); init_detector(&grad_norm_outlier_detector); - // do some checks here before we kick off training - // cross-check the desired sequence length T with the model's max sequence length - if (T < model.config.max_seq_len) { - printf0("!!!!!!!!\n"); - printf0("WARNING:\n"); - printf0("- The training sequence length is: T=%d (set with -t)\n", T); - printf0("- The model's max sequence length is: max_seq_len=%d\n", model.config.max_seq_len); - printf0("You are attempting to train with a sequence length shorter than the model's max.\n"); - printf0("This will lead to unused parameters in the wpe position embedding weights.\n"); - printf0("If you know what you're doing you can ignore this warning.\n"); - printf0("If you're like ???, you are most likely misconfiguring your training run.\n"); - printf0("---> HINT: If you're training GPT-2 use -t 1024. If GPT-3, use -t 2048.\n"); - printf0("!!!!!!!!\n"); - } - // in any case, this must be true or we'd index beyond the model's wpe (position embedding table) - assert(T <= model.config.max_seq_len); - // train cudaEvent_t start, end; cudaCheck(cudaEventCreate(&start)); @@ -1659,6 +1643,8 @@ int main(int argc, char *argv[]) { int last_step = step == train_num_batches; + if(0) { // TODO DELETE; START: IGNORE ALL THIS BLOCK WHILE GETTING STUFF TO WORK + // once in a while estimate the validation loss (all processes collaborate) if (step % val_loss_every == 0 || last_step) { NvtxRange validation_range("validation"); @@ -1756,6 +1742,7 @@ int main(int argc, char *argv[]) { } } resuming = 0; + } // TODO DELETE; END: IGNORE ALL THIS BLOCK WHILE GETTING STUFF TO WORK // bit confusing: we want to make sure to eval and sample on 0th iteration // but also after the very last iteration. so we loop for step <= train_num_batches From 72e6f1ab0b83ab252639949b4880568460a673fa Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Sep 2024 21:03:01 +0000 Subject: [PATCH 07/63] add new Encoder that does not use positional embeddings, like in llama 3. The activations match after encoding. onwards --- llmc/encoder.cuh | 26 +++++++++++++++++++++++++- train_llama3.cu | 16 ++++++++++++---- train_llama3.py | 7 +++++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index 3aa63e175..fbaf56af1 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -43,6 +43,24 @@ __global__ void encoder_forward_kernel3(floatX* out, store128(out_btc, packed_out); } +// same kernel but without the positional encoder +__global__ void encoder_forward_kernel3_nowpe(floatX* out, + const int* inp, const floatX* wte, + int B, int T, int C) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + int N = B * T * C; + if (idx >= N) { return; } + int bt = idx / C; + int b = bt / T; + int t = bt % T; + int c = idx % C; + int ix = inp[b * T + t]; + floatX* out_btc = out + b * T * C + t * C + c; + const floatX* wte_ix = wte + ix * C + c; + x128 wte128 = load128cs(wte_ix); + store128(out_btc, wte128); +} + template __global__ void wte_backward_kernel(floatX* dwte, const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp, @@ -161,7 +179,13 @@ void encoder_forward(floatX* out, const int block_size = 256; const int N = B * T * C; const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size)); - encoder_forward_kernel3<<>>(out, inp, wte, wpe, B, T, C); + if (wpe == NULL) { + // Llama 3 does not use positional encoder + encoder_forward_kernel3_nowpe<<>>(out, inp, wte, B, T, C); + } else { + // GPT-2 does, so we use the full encoder kernel + encoder_forward_kernel3<<>>(out, inp, wte, wpe, B, T, C); + } cudaCheck(cudaGetLastError()); } diff --git a/train_llama3.cu b/train_llama3.cu index 53e5fe728..9edf5d3f7 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -618,7 +618,18 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // forward pass ParameterTensors params = model->params; // for brevity ActivationTensors acts = model->acts; - encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0] + encoder_forward(acts.encoded, model->inputs, params.wte, NULL, B, T, C, main_stream); // encoding goes into residual[0] + + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, acts.encoded, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("cpu[%d] = %f\n", i, (float) cpu[i]); + } + exit(0); + // ------------------------------------------------------------------------ // first layernorm isn't fused layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream); @@ -1566,9 +1577,6 @@ int main(int argc, char *argv[]) { printf0("| val_num_batches | %-50d |\n", val_num_batches); printf0("+-----------------------+----------------------------------------------------+\n"); - // DEBUGGING: we only work until this point right now, so exit here - exit(0); - // build an EvalLoader for HellaSwag EvalLoader eval_loader; const char* hellaswag_path = "dev/data/hellaswag/hellaswag_val.bin"; diff --git a/train_llama3.py b/train_llama3.py index 7ed0cdf7b..ca324cd8a 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -300,6 +300,13 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) freqs_cis = self.freqs_cis[start_pos:start_pos+t] + # --------------------------------------------------------------------- + # DEBUGGING: print first 32 elements of x + for i in range(32): + print("acts[{}]: {}".format(i, x.view(-1)[i].item())) + breakpoint() + # --------------------------------------------------------------------- + mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) for i, block in enumerate(self.transformer.h): From 234de31fdf8306bf8cfcb1f550b4587e16fa4218 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 16 Sep 2024 21:43:15 +0000 Subject: [PATCH 08/63] introduce rmsnorm, unfused, forward --- llmc/layernorm.cuh | 84 ++++++++++++++++++++++++++++++++++++++++++++++ train_llama3.cu | 12 ++++--- train_llama3.py | 16 ++++----- 3 files changed, 99 insertions(+), 13 deletions(-) diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 9777d0658..1387b11ab 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -139,6 +139,66 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res } } +__global__ void rmsnorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ rms, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, int N, int C) { + // this kernel is a simplified version of layernorm_forward_kernel6 + assert(blockDim.x == WARP_SIZE); + + // load weights into shared memory + // do this before we allow any threads to exit! + extern __shared__ char* params[]; + // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so + // let's keep everything as x128 + x128* s_weight = reinterpret_cast(params); + x128* s_in = reinterpret_cast(params) + ((1 + threadIdx.y) * C / x128::size); + + int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; + for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { + s_weight[i/x128::size] = load128(weight + i); + } + __syncthreads(); + + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if(idx >= N) { return; } // guard + + // adjust pointers to current token + inp += idx * C; + out += idx * C; + + const float eps = 1e-5f; + float acc = 0.f; + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = load128cs(inp + c); + s_in[c / x128::size] = in_data; + for(int k = 0; k < x128::size; ++k) { + float data_k = (float)in_data[k]; + acc += data_k * data_k; + } + } + + acc = warpReduceSum(acc) / C; + float s = rsqrtf(acc + eps); + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = s_in[c / x128::size]; + const x128 w = s_weight[c / x128::size]; + x128 out_data; + for(int k = 0; k < x128::size; ++k) { + float n = s * (float)in_data[k]; // normalized output + float o = n * (float)w[k]; // scale + out_data[k] = (floatX)o; + } + + store128cs(out + c, out_data); + } + + // store the rms, no need to cache it + if(threadIdx.x == 0 && rms != nullptr) { + __stcs(rms + idx, s); + } +} + __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd, const floatX* inp1, const floatX* inp2, const floatX* weight, const floatX* bias, @@ -503,3 +563,27 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr layernorm_backward_kernel10<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); cudaCheck(cudaGetLastError()); } + +void rmsnorm_forward(floatX* out, float* rms, + floatX* inp, const floatX* weight, + int B, int T, int C, cudaStream_t stream) { + NVTX_RANGE_FN(); + const int block_size = 256; + int block_y = block_size / WARP_SIZE; + const int N = B * T; + const int grid_size = CEIL_DIV(N, block_y); + size_t smem = (1 + block_y) * C * sizeof(floatX); + + // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute + // this may fail, in which case we fall back to the smem free implementation. + cudaCheck(cudaGetLastError()); + auto status = cudaFuncSetAttribute(rmsnorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cudaCheck(cudaGetLastError()); + if (status == cudaSuccess) { + rmsnorm_forward_kernel6<<>>(out, rms, inp, weight, N, C); + } else { + // We should not allow for these perf regressions for now - just throw an error + assert(false); + } + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index 9edf5d3f7..6bc7e61f1 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -46,6 +46,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. // defines: encoder_forward, encoder_backward #include "llmc/encoder.cuh" // defines: layernorm_forward, residual_forward, fused_residual_forward5, layernorm_backward +// defines: rmsnorm_forward #include "llmc/layernorm.cuh" // defines: matmul_cublaslt, matmul_forward, matmul_backward, gelu_forward, gelu_backward_inplace #include "llmc/matmul.cuh" @@ -620,20 +621,21 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { ActivationTensors acts = model->acts; encoder_forward(acts.encoded, model->inputs, params.wte, NULL, B, T, C, main_stream); // encoding goes into residual[0] + // first layernorm isn't fused + rmsnorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_rstd, acts.encoded, params.ln1w, B, T, C, main_stream); + // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here // transfer the first 32 elements to CPU and print them + floatX* output = (model->recompute < 2) ? acts.ln1 : acts.lnf; floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, acts.encoded, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < 32; i++) { - printf("cpu[%d] = %f\n", i, (float) cpu[i]); + printf("cpu[%d] = %.8f\n", i, (float) cpu[i]); } exit(0); // ------------------------------------------------------------------------ - // first layernorm isn't fused - layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream); - for (int l = 0; l < L; l++) { NvtxRange layer_range("Layer", l); diff --git a/train_llama3.py b/train_llama3.py index ca324cd8a..3d7d262bf 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -166,6 +166,14 @@ def __init__(self, config): def forward(self, x, freqs_cis=None, start_pos=None, mask=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # --------------------------------------------------------------------- + # DEBUGGING: print first 32 elements of x + for i in range(32): + print("acts[{}]: {}".format(i, x.view(-1)[i].item())) + breakpoint() + # --------------------------------------------------------------------- + # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) @@ -299,14 +307,6 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): # forward the LLaMA model itself x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) freqs_cis = self.freqs_cis[start_pos:start_pos+t] - - # --------------------------------------------------------------------- - # DEBUGGING: print first 32 elements of x - for i in range(32): - print("acts[{}]: {}".format(i, x.view(-1)[i].item())) - breakpoint() - # --------------------------------------------------------------------- - mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) for i, block in enumerate(self.transformer.h): From 508c474bf9646f0929d78ed357adec104400b610 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Sep 2024 21:19:46 +0000 Subject: [PATCH 09/63] move debugging into fp32, so python has to write the fp32 version, and then we are focusing on the non-cudnn path at first. we're currently right after the first rmsnorm. the encoding right before this matched EXACTLY. but right now, after the first rmsnorm there is already an error of 1e-3 or so, which is highly suspicious so we are looking into it. --- train_llama3.cu | 43 +++++++++++++++++++++++-------------------- train_llama3.py | 5 ++--- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index 6bc7e61f1..da5b86c44 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -240,14 +240,19 @@ struct TensorSpec { #define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}; void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) { - size_t Vp = config.padded_vocab_size; - size_t L = config.num_layers; - size_t NH = config.num_heads; - size_t C = config.channels; + const size_t Vp = config.padded_vocab_size; + const size_t L = config.num_layers; + const size_t NH = config.num_heads; + const size_t C = config.channels; + const size_t n_head = config.num_heads; + const size_t n_kn_head = config.num_kv_heads; + const size_t hd = C / n_head; // head dimension + const size_t qkv_channels = (n_head + 2*n_kn_head) * hd; // Q, K, V channels + tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? L * B * T * C : 0); - tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T); + tensors[2] = TENSOR_SPEC(data->ln1_mean, 0); // Llama 3 does not use this activation tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T); tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C); #ifdef ENABLE_CUDNN @@ -269,9 +274,8 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T); tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T); tensors[16] = TENSOR_SPEC(data->losses, B * T); - tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); - tensors[18] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp))); - + tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * qkv_channels); + tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(NH*T, Vp))); tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); } @@ -281,17 +285,13 @@ void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS] for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { bytes += tensors[i].size * sizeof_dtype(tensors[i].type); } - printf0("allocating %d MiB for activations\n", (int)round(bytes / (1024 * 1024))); - void* acts_memory; cudaCheck(cudaMalloc((void**)&acts_memory, bytes)); - // cudaMalloc does not guarantee initial memory values so we memset the allocation here // this matters because e.g. non-cuDNN attention assumes the attention buffer is zeroed // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer cudaCheck(cudaMemset(acts_memory, 0, bytes)); - char* acts_memory_iterator = (char*)acts_memory; for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { // extra protection so we don't accidentally use an empty buffer @@ -602,6 +602,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; + const size_t n_head = model->config.num_heads; + const size_t n_kn_head = model->config.num_kv_heads; + const size_t hd = C / n_head; // head dimension + const size_t qkv_channels = (n_head + 2*n_kn_head) * hd; // Q, K, V channels // validate B,T are not larger than the values used at initialisation // (smaller B,T are okay for inference only) @@ -620,7 +624,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { ParameterTensors params = model->params; // for brevity ActivationTensors acts = model->acts; encoder_forward(acts.encoded, model->inputs, params.wte, NULL, B, T, C, main_stream); // encoding goes into residual[0] - // first layernorm isn't fused rmsnorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_rstd, acts.encoded, params.ln1w, B, T, C, main_stream); @@ -642,8 +645,8 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; // get the pointers of the weights for this layer - floatX* l_qkvw = params.qkvw + l * 3*C * C; - floatX* l_qkvb = params.qkvb + l * 3*C; + floatX* l_qkvw = params.qkvw + l * qkv_channels * C; + floatX* l_qkvb = params.qkvb + l * qkv_channels; floatX* l_attprojw = params.attprojw + l * C * C; floatX* l_attprojb = params.attprojb + l * C; floatX* l_ln2w = params.ln2w + l * C; @@ -655,7 +658,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; + floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; @@ -668,10 +671,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_residual3 = acts.residual3 + l * B * T * C; floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. - // now do the forward pass + // now start the block forward pass #ifdef ENABLE_CUDNN + matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); #else floatX* l_att = acts.att + l * B * NH * T * T; @@ -680,7 +683,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { } // these are only needed as scratchpads for the forward pass, but // need not be stored for backward - matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); + matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); #endif @@ -1374,7 +1377,7 @@ int main(int argc, char *argv[]) { // read in the (optional) command line arguments const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin"; - const char* load_filename = "llama3.1_8B_bf16.bin"; // bf16 weights of the Llama 3.1 8B model + const char* load_filename = "llama3.1_8B.bin"; // bf16 weights of the Llama 3.1 8B model const char* lr_scheduler_type = "cosine"; const char* output_log_dir = NULL; int checkpoint_every = 0; // write checkpoints every how many steps? diff --git a/train_llama3.py b/train_llama3.py index 3d7d262bf..62c52a959 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -170,7 +170,7 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): # --------------------------------------------------------------------- # DEBUGGING: print first 32 elements of x for i in range(32): - print("acts[{}]: {}".format(i, x.view(-1)[i].item())) + print("acts[{}]: {:.8f}".format(i, x.view(-1)[i].item())) breakpoint() # --------------------------------------------------------------------- @@ -178,9 +178,7 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): qkv = self.c_attn(x) q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 - if self.use_kv and not self.training and start_pos >= 0: # use kv-caching during inference self.cache_k[:B, start_pos : start_pos + T] = k self.cache_v[:B, start_pos : start_pos + T] = v @@ -1131,6 +1129,7 @@ def print0(*args, **kwargs): # save model params, in bfloat16 model_to_size = {"meta-llama/Meta-Llama-3.1-8B": "8B"} model_size_str = model_to_size[args.model] # e.g. "8B" + write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}.bin"), dtype="float32") write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16") # save x, y, logits, loss, and parameter gradients, for debugging C # always store these in fp32 to have an accurate reference (?) From 685617f1646c60ea851619167eae1eb756ec22ac Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 17 Sep 2024 21:31:18 +0000 Subject: [PATCH 10/63] make fp32 path in .py code work correctly --- train_llama3.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_llama3.py b/train_llama3.py index 62c52a959..08a4f7fb0 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -1101,6 +1101,10 @@ def print0(*args, **kwargs): assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path) + # convert the model to the desired precision + if args.dtype == "float32": + model = model.to(torch.float32) + model.train() if args.compile: if hasattr(config, "coordinate_descent_tuning"): From 56f956cc4a0c7d551e20e1026d52228ca075163f Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 21 Sep 2024 01:53:08 +0000 Subject: [PATCH 11/63] add repkv kernel to replicate K,V heads after the QKV projection --- dev/cuda/Makefile | 3 +- dev/cuda/repkv.cu | 188 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 dev/cuda/repkv.cu diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 6a7584f8d..5e2fada82 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -30,7 +30,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux- $(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ # Build all targets -TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute +TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute repkv all: $(TARGETS) all_ptx: $(TARGETS:%=%.ptx) @@ -66,6 +66,7 @@ adamw: adamw.cu global_norm: global_norm.cu permute: permute.cu +repkv: repkv.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu diff --git a/dev/cuda/repkv.cu b/dev/cuda/repkv.cu new file mode 100644 index 000000000..aa816294c --- /dev/null +++ b/dev/cuda/repkv.cu @@ -0,0 +1,188 @@ +/* +Layer that takes a QKV tensor of shape (B, T, C) and replicates the K,V +some number of times. For example, if B=4, T=64, C=6144, and we have that: +- head dimension (hd) is 128 channels +- query heads: 32 +- key heads: 8 +- value heads: 8 +- so number of heads = 32 + 8 + 8 = 48, each of 128 channels, total of 6144 channels +We want to replicate the key/value vectors 4X, so that we get: +32 + 32 + 32 = 96 query, key, value heads, each of 128 channels, total of 12288 channels +Each of these vectors should be replicated by simple copying/concat 4X times. + +Compile and run as: +make repkv +./repkv + +block_size 128 seems fastest on H100 +*/ + +#include +#include +#include +#include +#include "common.h" + +// cpu reference code +void repkv_forward_cpu(float* out, const float* inp, + int B, int T, int C, + int hd, int qh, int kh, int vh) { + // inp is (B, T, C) + // hd = head dimension + // qh, kh, vh = number of query, key, value heads + assert(C == hd * (qh + kh + vh)); + assert(kh == vh); + int nrep = qh / kh; // number of times to replicate key/value vectors + int Cout = hd * (qh * 3); // output channels + + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + // seek to the input position inp[b,t,:] + const float* x = inp + b * T * C + t * C; + // seek to the output position out[b,t,:] + float* y = out + b * T * Cout + t * Cout; + // copy all the query vectors, no changes + for (int i = 0; i < hd * qh; i++) { y[i] = x[i]; } + x += hd * qh; // advance input pointer + y += hd * qh; // advance output pointer + // copy key vectors, and replicate them nrep times + for (int h = 0; h < kh; h++) { + for (int n = 0; n < nrep; n++) { + for (int i = 0; i < hd; i++) { y[i] = x[i]; } + y += hd; // advance output pointer + } + x += hd; // advance input pointer + } + // copy value vectors, and replicate them nrep times + for (int h = 0; h < vh; h++) { + for (int n = 0; n < nrep; n++) { + for (int i = 0; i < hd; i++) { y[i] = x[i]; } + y += hd; // advance output pointer + } + x += hd; // advance input pointer + } + } + } +} + +// kernels +__global__ void repkv_forward_kernel1(floatX* replicated_qkv, + const floatX* gqa_qkv, + int B, int N, int NH, int replicate_factor, int HD) { + // we have a single tensor gqa_qkv of shape (B, N, (NH + 2*(NH/replicate_factor)) * HD) + // we want to replicate it into (B, N, 3 * NH * HD) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * N * 3 * NH * HD) { return; } + int idx_flat = idx; // keep backup + + // decode the output index + int d = idx % HD; + idx /= HD; + int nh = idx % NH; + idx /= NH; + int c = idx % 3; + idx /= 3; + int n = idx % N; + int b = idx / N; + + int inp_idx; + int nh_total = NH + 2 * (NH / replicate_factor); + if (c == 0) { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; + } else if (c == 1) { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + } else { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + } + + replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); +} + +// kernel launchers +void repkv_forward1(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int d, int block_size) { + int total_threads = B * T * (3 * NH) * d; + int num_blocks = ceil_div(total_threads, block_size); + int replicate_factor = NH / NH_KV; + repkv_forward_kernel1<<>>(out, inp, B, T, NH, replicate_factor, d); + cudaCheck(cudaGetLastError()); +} + +// kernel dispatcher +void repkv_forward(int kernel_num, + floatX* out, const floatX* inp, + int B, int T, int NH, int NH_KV, int d, + int block_size) { + switch (kernel_num) { + case 1: + repkv_forward1(out, inp, B, T, NH, NH_KV, d, block_size); + break; + default: + printf("Invalid kernel number\n"); + exit(1); + } +} + +// tester +int main(int argc, char **argv) { + srand(0); + + int B = 8; + int T = 1024; + int hd = 128; // head dim + int qh = 32; // num query heads + int kh = 8; // num key heads + int vh = 8; // num value heads + + int deviceIdx = 0; + cudaCheck(cudaSetDevice(deviceIdx)); + + int C = hd * (qh + kh + vh); // input channels + int Cout = hd * (qh * 3); // output channels + + // allocate (and fill) CPU memory + float* inp = make_random_float(B * T * C); + float* out = (float*)malloc(B * T * Cout * sizeof(float)); + + // allocate GPU memory + float* d_inp; + float* d_out; + cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float))); + cudaCheck(cudaMalloc(&d_out, B * T * Cout * sizeof(float))); + + // read kernel_num from command line + int kernel_num = 1; + if (argc > 1) { + kernel_num = atoi(argv[1]); + } + printf("Using kernel %d\n", kernel_num); + + // CPU reference calculate + repkv_forward_cpu(out, inp, B, T, C, hd, qh, kh, vh); + + // check the correctness of the kernel at all block sizes + int block_sizes[] = {32, 64, 128, 256, 512, 1024}; + cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice)); + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + printf("Checking block size %d.\n", block_size); + repkv_forward(kernel_num, d_out, d_inp, B, T, qh, kh, hd, block_size); + validate_result(d_out, out, "out", B * T * Cout, 1e-5f); + } + printf("All results match. Starting benchmarks.\n\n"); + + // now benchmark + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + int repeat_times = 1000; + float elapsed_time = benchmark_kernel(repeat_times, repkv_forward, kernel_num, + d_out, d_inp, B, T, qh, kh, hd, block_size); + printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); + } + + // free memory + free(inp); + free(out); + cudaCheck(cudaFree(d_inp)); + cudaCheck(cudaFree(d_out)); +} + From 45401b42eb3d179a503549603192d68d7c87b588 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Sat, 21 Sep 2024 20:40:57 -0700 Subject: [PATCH 12/63] DRAFT: Adding backward kernel for repkv - [ ] WIP: CPU kernel - [ ] Cuda kernel --- dev/cuda/Makefile | 1 + dev/cuda/repkv_backward.cu | 266 +++++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 dev/cuda/repkv_backward.cu diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 5e2fada82..0fc69728a 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -67,6 +67,7 @@ global_norm: global_norm.cu permute: permute.cu repkv: repkv.cu +repkv_backward: repkv_backward.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu new file mode 100644 index 000000000..4694a0806 --- /dev/null +++ b/dev/cuda/repkv_backward.cu @@ -0,0 +1,266 @@ +/* + +TODO: update after CPU kernel + +Layer that takes a QKV tensor of shape (B, T, C) and replicates the K,V +some number of times. For example, if B=4, T=64, C=6144, and we have that: +- head dimension (hd) is 128 channels +- query heads: 32 +- key heads: 8 +- value heads: 8 +- so number of heads = 32 + 8 + 8 = 48, each of 128 channels, total of 6144 channels +We want to replicate the key/value vectors 4X, so that we get: +32 + 32 + 32 = 96 query, key, value heads, each of 128 channels, total of 12288 channels +Each of these vectors should be replicated by simple copying/concat 4X times. + +Compile and run as: +make repkv +./repkv + +block_size 128 seems fastest on H100 +*/ + +#include +#include +#include +#include +#include "common.h" + +// cpu reference code +void repkv_backward_cpu(float* dinp, const float* inp, const float* dout, + const int B, const int T, const int Cout, + const int hd, const int qh, const int kh, const int vh) { + + assert(Cout == (hd * (3 * qh))); + assert(kh == vh); + // assert((kh % nrep == 0) && (vh % nrep == 0)); + // int kh_g = kh / nrep; + // int vh_g = vh / nrep; + + int nrep = qh / kh; // number of times to replicate key/value vectors + + int Cin = hd * (qh + kh + vh); // output channels + + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + // seek to the input position dout[b,t,:] + // TODO check + const float* x = dout + b * T * Cout + t * Cout; + // seek to the output position out[b,t,:] + float* y = dinp + b * T * Cin + t * Cin; + // copy all the query vectors, no changes + for (int i = 0; i < hd * qh; i++) { y[i] = x[i]; } + x += hd * qh; // advance input pointer + y += hd * qh; // advance output pointer + // copy key vectors, and replicate them nrep times + for (int h = 0; h < kh; h++) { + // initilize + // for (int i = 0; i < hd; i++) { y[i] = 0.0f; } + for (int n = 0; n < nrep; n++) { + for (int i = 0; i < hd; i++) { y[i] += x[i]; } + x += hd; // advance input pointer + } + y += hd; // advance output pointer + } + // copy value vectors, and replicate them nrep times + for (int h = 0; h < vh; h++) { + // initilize + // for (int i = 0; i < hd; i++) { y[i] = 0.0f; } + for (int n = 0; n < nrep; n++) { + for (int i = 0; i < hd; i++) { y[i] += x[i]; } + x += hd; // advance input pointer + } + y += hd; // advance output pointer + } + } + } +} + +// TODO: update after CPU kernel +// kernels +__global__ void repkv_backward_kernel1(floatX* replicated_qkv, + const floatX* gqa_qkv, + int B, int N, int NH, int replicate_factor, int HD) { + // we have a single tensor gqa_qkv of shape (B, N, (NH + 2*(NH/replicate_factor)) * HD) + // we want to replicate it into (B, N, 3 * NH * HD) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * N * 3 * NH * HD) { return; } + int idx_flat = idx; // keep backup + + // decode the output index + int d = idx % HD; + idx /= HD; + int nh = idx % NH; + idx /= NH; + int c = idx % 3; + idx /= 3; + int n = idx % N; + int b = idx / N; + + int inp_idx; + int nh_total = NH + 2 * (NH / replicate_factor); + if (c == 0) { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; + } else if (c == 1) { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + } else { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + } + + replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); +} + +// TODO: update after CPU kernel +// kernel launchers +void repkv_backward1(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int d, int block_size) { + int total_threads = B * T * (3 * NH) * d; + int num_blocks = ceil_div(total_threads, block_size); + int replicate_factor = NH / NH_KV; + repkv_backward_kernel1<<>>(out, inp, B, T, NH, replicate_factor, d); + cudaCheck(cudaGetLastError()); +} + +// TODO: update after CPU kernel +// kernel dispatcher +void repkv_backward(int kernel_num, + floatX* out, const floatX* inp, + int B, int T, int NH, int NH_KV, int d, + int block_size) { + switch (kernel_num) { + case 1: + repkv_backward1(out, inp, B, T, NH, NH_KV, d, block_size); + break; + default: + printf("Invalid kernel number\n"); + exit(1); + } +} + +void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, char *title) +{ + printf("%s -----\n", title); + for (int b = 0; b < B; b++) { + printf("batch : %d ", b); + for (int t = 0; t < T; t++) { + printf("t = %d\n", t); + const float* x = inp + b * T * C + t * C; + printf("Query\n"); + for (int h=0; h < qh; h++) { + for (int i = 0; i < hd; i++) { + printf("%f ", x[i]); + } + x += hd; // advance input pointer + printf("\n"); + } + printf("Key\n"); + for (int h=0; h < kh; h++) { + for (int i = 0; i < hd; i++) { + printf("%f ", x[i]); + } + x += hd; // advance input pointer + printf("\n"); + } + printf("Value\n"); + for (int h=0; h < vh; h++) { + for (int i = 0; i < hd; i++) { + printf("%f ", x[i]); + } + x += hd; // advance input pointer + printf("\n"); + } + } + } + printf("\n"); +} + +// TODO: update after CPU kernel +// tester +int main(int argc, char **argv) { + srand(0); + +#ifndef DEBUG + int B = 1; + int T = 2; + int hd = 3; // head dim + int qh = 4; // num query heads + int kh = 2; // num key heads + int vh = 2; // num value heads + int nrep = qh/kh; +#else + int B = 8; + int T = 1024; + int hd = 128; // head dim + int qh = 32; // num query heads + int kh = 8; // num key heads + int vh = 8; // num value heads +#endif + + int deviceIdx = 0; + cudaCheck(cudaSetDevice(deviceIdx)); + + int Cout = hd * (qh * 3); // out, upstream channels + int Cin = hd * (qh + kh + vh); // in, downstream channels + // int nrep = 4; + + // allocate (and fill) CPU memory + float* dinp = (float*)malloc(B * T * Cin * sizeof(float)); + memset(dinp, 0, B * T * Cin * sizeof(float)); + float* inp = make_random_float(B * T * Cin); + float* doutp = make_random_float(B * T * Cout * sizeof(float)); + + // TODO: update after CPU kernel + // allocate GPU memory +#if 0 + float* d_inp; + float* d_out; + cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float))); + cudaCheck(cudaMalloc(&d_out, B * T * Cout * sizeof(float))); +#endif + + // read kernel_num from command line + int kernel_num = 1; + if (argc > 1) { + kernel_num = atoi(argv[1]); + } + printf("Using kernel %d\n", kernel_num); + + log_mat(doutp, B, T, Cout, hd, qh, nrep*kh, nrep*vh, "doutp"); + + // TODO: update + // CPU reference calculate + repkv_backward_cpu(dinp, inp, doutp, B, T, Cout, hd, qh, kh, vh); + + log_mat(dinp, B, T, Cout, hd, qh, kh, vh, "dinp"); + + // TODO: update after CPU kernel +#if 0 + // check the correctness of the kernel at all block sizes + int block_sizes[] = {32, 64, 128, 256, 512, 1024}; + cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice)); + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + printf("Checking block size %d.\n", block_size); + repkv_forward(kernel_num, d_out, d_inp, B, T, qh, kh, hd, block_size); + validate_result(d_out, out, "out", B * T * Cout, 1e-5f); + } + printf("All results match. Starting benchmarks.\n\n"); + + // now benchmark + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + int repeat_times = 1000; + float elapsed_time = benchmark_kernel(repeat_times, repkv_forward, kernel_num, + d_out, d_inp, B, T, qh, kh, hd, block_size); + printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); + } +#endif + // free memory + free(inp); + free(dinp); + free(doutp); + + // TODO: update after CPU kernel + // cudaCheck(cudaFree(d_inp)); + // cudaCheck(cudaFree(d_out)); +} + From 080e57fd466fc6e5d4c556d2d95c72651c1e6d5b Mon Sep 17 00:00:00 2001 From: Insop Song Date: Sat, 21 Sep 2024 21:21:13 -0700 Subject: [PATCH 13/63] CPU version tested - [ ] WIP cuda version --- dev/cuda/repkv_backward.cu | 54 +++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index 4694a0806..f35e7712b 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -27,15 +27,12 @@ block_size 128 seems fastest on H100 #include "common.h" // cpu reference code -void repkv_backward_cpu(float* dinp, const float* inp, const float* dout, +void repkv_backward_cpu(float* dinp, const float* inp, const float* doutp, const int B, const int T, const int Cout, const int hd, const int qh, const int kh, const int vh) { assert(Cout == (hd * (3 * qh))); assert(kh == vh); - // assert((kh % nrep == 0) && (vh % nrep == 0)); - // int kh_g = kh / nrep; - // int vh_g = vh / nrep; int nrep = qh / kh; // number of times to replicate key/value vectors @@ -43,9 +40,8 @@ void repkv_backward_cpu(float* dinp, const float* inp, const float* dout, for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - // seek to the input position dout[b,t,:] - // TODO check - const float* x = dout + b * T * Cout + t * Cout; + // seek to the input position doutp[b,t,:] + const float* x = doutp + b * T * Cout + t * Cout; // seek to the output position out[b,t,:] float* y = dinp + b * T * Cin + t * Cin; // copy all the query vectors, no changes @@ -54,8 +50,6 @@ void repkv_backward_cpu(float* dinp, const float* inp, const float* dout, y += hd * qh; // advance output pointer // copy key vectors, and replicate them nrep times for (int h = 0; h < kh; h++) { - // initilize - // for (int i = 0; i < hd; i++) { y[i] = 0.0f; } for (int n = 0; n < nrep; n++) { for (int i = 0; i < hd; i++) { y[i] += x[i]; } x += hd; // advance input pointer @@ -64,8 +58,6 @@ void repkv_backward_cpu(float* dinp, const float* inp, const float* dout, } // copy value vectors, and replicate them nrep times for (int h = 0; h < vh; h++) { - // initilize - // for (int i = 0; i < hd; i++) { y[i] = 0.0f; } for (int n = 0; n < nrep; n++) { for (int i = 0; i < hd; i++) { y[i] += x[i]; } x += hd; // advance input pointer @@ -112,7 +104,8 @@ __global__ void repkv_backward_kernel1(floatX* replicated_qkv, // TODO: update after CPU kernel // kernel launchers -void repkv_backward1(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int d, int block_size) { +void repkv_backward1(floatX* dinp, const floatX* inp, const floatX* doutp, + const int B, const int T, const int NH, const int NH_KV, const int d, int block_size) { int total_threads = B * T * (3 * NH) * d; int num_blocks = ceil_div(total_threads, block_size); int replicate_factor = NH / NH_KV; @@ -193,6 +186,7 @@ int main(int argc, char **argv) { int qh = 32; // num query heads int kh = 8; // num key heads int vh = 8; // num value heads + int nrep = qh/kh; #endif int deviceIdx = 0; @@ -200,7 +194,6 @@ int main(int argc, char **argv) { int Cout = hd * (qh * 3); // out, upstream channels int Cin = hd * (qh + kh + vh); // in, downstream channels - // int nrep = 4; // allocate (and fill) CPU memory float* dinp = (float*)malloc(B * T * Cin * sizeof(float)); @@ -208,14 +201,16 @@ int main(int argc, char **argv) { float* inp = make_random_float(B * T * Cin); float* doutp = make_random_float(B * T * Cout * sizeof(float)); - // TODO: update after CPU kernel // allocate GPU memory -#if 0 + float* d_dinp; float* d_inp; - float* d_out; - cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float))); - cudaCheck(cudaMalloc(&d_out, B * T * Cout * sizeof(float))); -#endif + float* d_doutp; + cudaCheck(cudaMalloc(&d_dinp, B * T * Cin * sizeof(float))); + cudaCheck(cudaMalloc(&d_inp, B * T * Cin * sizeof(float))); + cudaCheck(cudaMalloc(&d_doutp, B * T * Cout * sizeof(float))); + + // cudaCheck(memcpy_convert(d_inp, inp, B * T * Cin)); + // cudaCheck(memcpy_convert(d_doutp, doutp, B * T * Cout)); // read kernel_num from command line int kernel_num = 1; @@ -226,22 +221,21 @@ int main(int argc, char **argv) { log_mat(doutp, B, T, Cout, hd, qh, nrep*kh, nrep*vh, "doutp"); - // TODO: update // CPU reference calculate repkv_backward_cpu(dinp, inp, doutp, B, T, Cout, hd, qh, kh, vh); - log_mat(dinp, B, T, Cout, hd, qh, kh, vh, "dinp"); + log_mat(dinp, B, T, Cin, hd, qh, kh, vh, "dinp"); // TODO: update after CPU kernel -#if 0 // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; - cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(d_inp, inp, B * T * Cin * sizeof(float), cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(d_doutp, doutp, B * T * Cout * sizeof(float), cudaMemcpyHostToDevice)); for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); - repkv_forward(kernel_num, d_out, d_inp, B, T, qh, kh, hd, block_size); - validate_result(d_out, out, "out", B * T * Cout, 1e-5f); + repkv_backward(kernel_num, d_out, d_inp, B, T, qh, kh, hd, block_size); + validate_result(d_dinp, dinp, "out", B * T * Cin, 1e-5f); } printf("All results match. Starting benchmarks.\n\n"); @@ -249,18 +243,18 @@ int main(int argc, char **argv) { for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 1000; - float elapsed_time = benchmark_kernel(repeat_times, repkv_forward, kernel_num, + float elapsed_time = benchmark_kernel(repeat_times, repkv_backward, kernel_num, d_out, d_inp, B, T, qh, kh, hd, block_size); printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } -#endif + // free memory free(inp); free(dinp); free(doutp); - // TODO: update after CPU kernel - // cudaCheck(cudaFree(d_inp)); - // cudaCheck(cudaFree(d_out)); + cudaCheck(cudaFree(d_dinp)); + cudaCheck(cudaFree(d_inp)); + cudaCheck(cudaFree(d_doutp)); } From 6c68657c6b33704d9582c95c5db7e81f274b8e71 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Sat, 21 Sep 2024 21:30:45 -0700 Subject: [PATCH 14/63] Put cuda kernel caller placeholder --- dev/cuda/repkv_backward.cu | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index f35e7712b..24fbe4e73 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -68,11 +68,13 @@ void repkv_backward_cpu(float* dinp, const float* inp, const float* doutp, } } -// TODO: update after CPU kernel // kernels -__global__ void repkv_backward_kernel1(floatX* replicated_qkv, - const floatX* gqa_qkv, - int B, int N, int NH, int replicate_factor, int HD) { +__global__ void repkv_backward_kernel1(floatX* dinp, + const floatX* inp, const floatX* doutp, + int B, int N, int NH, int replicate_factor, int HD) { + + // TODO: update after CPU kernel +#if 0 // we have a single tensor gqa_qkv of shape (B, N, (NH + 2*(NH/replicate_factor)) * HD) // we want to replicate it into (B, N, 3 * NH * HD) int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -100,6 +102,7 @@ __global__ void repkv_backward_kernel1(floatX* replicated_qkv, } replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); +#endif } // TODO: update after CPU kernel @@ -109,19 +112,19 @@ void repkv_backward1(floatX* dinp, const floatX* inp, const floatX* doutp, int total_threads = B * T * (3 * NH) * d; int num_blocks = ceil_div(total_threads, block_size); int replicate_factor = NH / NH_KV; - repkv_backward_kernel1<<>>(out, inp, B, T, NH, replicate_factor, d); + repkv_backward_kernel1<<>>(dinp, inp, doutp, B, T, NH, replicate_factor, d); cudaCheck(cudaGetLastError()); } // TODO: update after CPU kernel // kernel dispatcher void repkv_backward(int kernel_num, - floatX* out, const floatX* inp, + floatX* dinp, const floatX* inp, const floatX* doutp, int B, int T, int NH, int NH_KV, int d, int block_size) { switch (kernel_num) { case 1: - repkv_backward1(out, inp, B, T, NH, NH_KV, d, block_size); + repkv_backward1(dinp, inp, doutp, B, T, NH, NH_KV, d, block_size); break; default: printf("Invalid kernel number\n"); @@ -209,7 +212,7 @@ int main(int argc, char **argv) { cudaCheck(cudaMalloc(&d_inp, B * T * Cin * sizeof(float))); cudaCheck(cudaMalloc(&d_doutp, B * T * Cout * sizeof(float))); - // cudaCheck(memcpy_convert(d_inp, inp, B * T * Cin)); + // cudaCheck(memcpy_convert(d_dinp, inp, B * T * Cin)); // cudaCheck(memcpy_convert(d_doutp, doutp, B * T * Cout)); // read kernel_num from command line @@ -229,16 +232,18 @@ int main(int argc, char **argv) { // TODO: update after CPU kernel // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; - cudaCheck(cudaMemcpy(d_inp, inp, B * T * Cin * sizeof(float), cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(d_dinp, inp, B * T * Cin * sizeof(float), cudaMemcpyHostToDevice)); cudaCheck(cudaMemcpy(d_doutp, doutp, B * T * Cout * sizeof(float), cudaMemcpyHostToDevice)); for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); - repkv_backward(kernel_num, d_out, d_inp, B, T, qh, kh, hd, block_size); + repkv_backward(kernel_num, d_dinp, d_inp, d_doutp, B, T, qh, kh, hd, block_size); validate_result(d_dinp, dinp, "out", B * T * Cin, 1e-5f); } printf("All results match. Starting benchmarks.\n\n"); + // TODO: update +#if 0 // now benchmark for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; @@ -247,6 +252,7 @@ int main(int argc, char **argv) { d_out, d_inp, B, T, qh, kh, hd, block_size); printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } +#endif // free memory free(inp); From ad46043aaf3efeb9dc4fb216212863fbf551c842 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Sun, 22 Sep 2024 00:22:26 -0700 Subject: [PATCH 15/63] WIP updating cuda kernel --- dev/cuda/repkv_backward.cu | 41 +++++--------------------------------- 1 file changed, 5 insertions(+), 36 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index 24fbe4e73..ec03fce41 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -1,6 +1,6 @@ /* -TODO: update after CPU kernel +TODO: update the description Layer that takes a QKV tensor of shape (B, T, C) and replicates the K,V some number of times. For example, if B=4, T=64, C=6144, and we have that: @@ -14,8 +14,8 @@ We want to replicate the key/value vectors 4X, so that we get: Each of these vectors should be replicated by simple copying/concat 4X times. Compile and run as: -make repkv -./repkv +make repkv_backward +./repkv_backward 1 block_size 128 seems fastest on H100 */ @@ -74,38 +74,8 @@ __global__ void repkv_backward_kernel1(floatX* dinp, int B, int N, int NH, int replicate_factor, int HD) { // TODO: update after CPU kernel -#if 0 - // we have a single tensor gqa_qkv of shape (B, N, (NH + 2*(NH/replicate_factor)) * HD) - // we want to replicate it into (B, N, 3 * NH * HD) - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= B * N * 3 * NH * HD) { return; } - int idx_flat = idx; // keep backup - - // decode the output index - int d = idx % HD; - idx /= HD; - int nh = idx % NH; - idx /= NH; - int c = idx % 3; - idx /= 3; - int n = idx % N; - int b = idx / N; - - int inp_idx; - int nh_total = NH + 2 * (NH / replicate_factor); - if (c == 0) { - inp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; - } else if (c == 1) { - inp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; - } else { - inp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; - } - - replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); -#endif } -// TODO: update after CPU kernel // kernel launchers void repkv_backward1(floatX* dinp, const floatX* inp, const floatX* doutp, const int B, const int T, const int NH, const int NH_KV, const int d, int block_size) { @@ -116,7 +86,6 @@ void repkv_backward1(floatX* dinp, const floatX* inp, const floatX* doutp, cudaCheck(cudaGetLastError()); } -// TODO: update after CPU kernel // kernel dispatcher void repkv_backward(int kernel_num, floatX* dinp, const floatX* inp, const floatX* doutp, @@ -132,6 +101,7 @@ void repkv_backward(int kernel_num, } } +// TODO: update void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, char *title) { printf("%s -----\n", title); @@ -169,12 +139,11 @@ void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, ch printf("\n"); } -// TODO: update after CPU kernel // tester int main(int argc, char **argv) { srand(0); -#ifndef DEBUG +#ifdef DEBUG int B = 1; int T = 2; int hd = 3; // head dim From 42d09e8732b1e0c7ac0b6b72f16e1ff49171accc Mon Sep 17 00:00:00 2001 From: Insop Song Date: Sun, 22 Sep 2024 00:26:29 -0700 Subject: [PATCH 16/63] minor clean up --- dev/cuda/repkv_backward.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index ec03fce41..fdf291244 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -181,9 +181,6 @@ int main(int argc, char **argv) { cudaCheck(cudaMalloc(&d_inp, B * T * Cin * sizeof(float))); cudaCheck(cudaMalloc(&d_doutp, B * T * Cout * sizeof(float))); - // cudaCheck(memcpy_convert(d_dinp, inp, B * T * Cin)); - // cudaCheck(memcpy_convert(d_doutp, doutp, B * T * Cout)); - // read kernel_num from command line int kernel_num = 1; if (argc > 1) { @@ -198,7 +195,6 @@ int main(int argc, char **argv) { log_mat(dinp, B, T, Cin, hd, qh, kh, vh, "dinp"); - // TODO: update after CPU kernel // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; cudaCheck(cudaMemcpy(d_dinp, inp, B * T * Cin * sizeof(float), cudaMemcpyHostToDevice)); From fcc3466b2f660a7957f22a2a728ec04ee3d5229c Mon Sep 17 00:00:00 2001 From: Insop Song Date: Sun, 22 Sep 2024 09:05:56 -0700 Subject: [PATCH 17/63] Add minor change --- dev/cuda/repkv_backward.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index fdf291244..bbb57c6d1 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -188,12 +188,16 @@ int main(int argc, char **argv) { } printf("Using kernel %d\n", kernel_num); +#ifdef DEBUG log_mat(doutp, B, T, Cout, hd, qh, nrep*kh, nrep*vh, "doutp"); +#endif // DEBUG // CPU reference calculate repkv_backward_cpu(dinp, inp, doutp, B, T, Cout, hd, qh, kh, vh); +#ifdef DEBUG log_mat(dinp, B, T, Cin, hd, qh, kh, vh, "dinp"); +#endif // DEBUG // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; From de9c8170e5fc6ee2974bfe703d3994541e73cc89 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Mon, 23 Sep 2024 20:52:50 -0700 Subject: [PATCH 18/63] wip --- dev/cuda/Makefile | 2 +- dev/cuda/repkv_backward.cu | 47 +++++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 0fc69728a..9661571e9 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -19,7 +19,7 @@ endif ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY= CFLAGS = -O3 --use_fast_math else - CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] + CFLAGS = -DDEBUG -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] endif NVCCFLAGS = -lcublas -lcublasLt -std=c++17 diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index bbb57c6d1..e43b121e9 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -72,8 +72,53 @@ void repkv_backward_cpu(float* dinp, const float* inp, const float* doutp, __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* inp, const floatX* doutp, int B, int N, int NH, int replicate_factor, int HD) { + // we have a single tensor doutp of shapae of (B, N 3 * NH * HD) + // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD) + int idx = blockIdx.x * blockDim.x + threadIdx.x; - // TODO: update after CPU kernel + // ?? + if (idx >= B * N * 3 * NH * HD) { return;} + // ?? + int doutp_idx = idx; // keep backp + + // decode the doutp index + int d = idx % HD; + idx /= HD; + int nh = idx % NH; + idx /= NH; + int c = idx % 3; + idx /= 3; + int n = idx % N; + int b = idx / N; + + int dinp_idx; + // int nh_total = NH * 3; + int nh_total = NH + 2 * (NH / replicate_factor); + + if (c == 0) { + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD * nh * HD + d; + dinp[dinp_idx] = __ldca(&doutp[doutp_idx]); + } else if (c == 1) { + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + // float reduced_sum = 0; + // if (doutp_idx % replicate_factor == 0) { + // for (int i = doutp_idx; i < doutp_idx+replicate_factor; i++) + // reduced_sum += __ldcs(&doutp[i]); + // dinp[dinp_idx] = reduced_sum; + // } + // ?? + dinp[dinp_idx] = __ldca(&doutp[doutp_idx]); + } else { + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + // float reduced_sum = 0; + // if (doutp_idx % replicate_factor == 0) { + // for (int i = doutp_idx; i < doutp_idx + replicate_factor; i++) + // reduced_sum += __ldcs(&doutp[i]); + // dinp[dinp_idx] = reduced_sum; + // } + // ?? + dinp[dinp_idx] = __ldca(&doutp[doutp_idx]); + } } // kernel launchers From 76b40e43ba7b1078f359d863f18a034168aefe3d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 24 Sep 2024 17:30:51 +0000 Subject: [PATCH 19/63] integrate the repkv kernel with minor changes. use the bt4c buffer for the replication. rope is next --- llmc/repkv.cuh | 63 +++++++++++++++++++++++++++++++++++++++++++ train_llama3.cu | 72 +++++++++++++++++++++++++++---------------------- 2 files changed, 103 insertions(+), 32 deletions(-) create mode 100644 llmc/repkv.cuh diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh new file mode 100644 index 000000000..055f15ea7 --- /dev/null +++ b/llmc/repkv.cuh @@ -0,0 +1,63 @@ +/* +Layer that takes a QKV tensor of shape (B, T, C) and replicates the K,V +some number of times. For example, if B=4, T=64, C=6144, and we have that: +- head dimension (hd) is 128 channels +- query heads: 32 +- key heads: 8 +- value heads: 8 +- so number of heads = 32 + 8 + 8 = 48, each of 128 channels, total of 6144 channels +We want to replicate the key/value vectors 4X, so that we get: +32 + 32 + 32 = 96 query, key, value heads, each of 128 channels, total of 12288 channels +Each of these vectors should be replicated by simple copying/concat 4X times. + +See dev/cuda/repkv.cu for correctness and performance reference +block_size 128 seems fastest on H100 +*/ + +#include "cuda_common.h" + +__global__ void repkv_forward_kernel1(floatX* replicated_qkv, + const floatX* gqa_qkv, + int B, int N, int NH, int replicate_factor, int HD) { + // we have a single tensor gqa_qkv of shape (B, N, (NH + 2*(NH/replicate_factor)) * HD) + // we want to replicate it into (B, N, 3 * NH * HD) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * N * 3 * NH * HD) { return; } + int idx_flat = idx; // keep backup + + // decode the output index + int d = idx % HD; + idx /= HD; + int nh = idx % NH; + idx /= NH; + int c = idx % 3; + idx /= 3; + int n = idx % N; + int b = idx / N; + + int inp_idx; + int nh_total = NH + 2 * (NH / replicate_factor); + if (c == 0) { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; + } else if (c == 1) { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + } else { + inp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + } + + replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); +} + +void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD) { + // NH = number of query heads, NH_KV = number of key and value heads, HD = head dimension + const int block_size = 128; + int total_threads = B * T * (3 * NH) * HD; // one thread per output element + int num_blocks = CEIL_DIV(total_threads, block_size); + int replicate_factor = NH / NH_KV; + if (replicate_factor > 1) { + repkv_forward_kernel1<<>>(out, inp, B, T, NH, replicate_factor, HD); + } else { + cudaMemcpy(out, inp, total_threads * sizeof(floatX), cudaMemcpyDeviceToDevice); + } + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index da5b86c44..a83561941 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -63,6 +63,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/adamw.cuh" // defines: global_norm_squared #include "llmc/global_norm.cuh" +// defines: repkv_forward +#include "llmc/repkv.cuh" // ----------- Multi-GPU support ----------- // defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo // defines: printf0, multi_gpu_config @@ -133,9 +135,9 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf // calculation following the .py code inside CausalSelfAttention // we have to calculate the number of channels in the QKV projection size_t n_head = config.num_heads; - size_t n_kn_head = config.num_kv_heads; + size_t n_kv_head = config.num_kv_heads; size_t hd = C / n_head; // head dimension - size_t qkv_channels = (n_head + 2*n_kn_head) * hd; // Q, K, V channels + size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels // calculation following the .py code inside MLP // we have to calculate the number of channels in the SwiGLU projections c_fc + c_fc2 size_t hidden_dim = 4 * C; @@ -244,10 +246,10 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor const size_t L = config.num_layers; const size_t NH = config.num_heads; const size_t C = config.channels; - const size_t n_head = config.num_heads; - const size_t n_kn_head = config.num_kv_heads; - const size_t hd = C / n_head; // head dimension - const size_t qkv_channels = (n_head + 2*n_kn_head) * hd; // Q, K, V channels + const size_t n_head = config.num_heads; // num query heads + const size_t n_kv_head = config.num_kv_heads; // num key and value heads + const size_t hd = C / n_head; // the size of each head + const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass @@ -603,9 +605,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { const size_t NH = model->config.num_heads; const size_t C = model->config.channels; const size_t n_head = model->config.num_heads; - const size_t n_kn_head = model->config.num_kv_heads; + const size_t n_kv_head = model->config.num_kv_heads; const size_t hd = C / n_head; // head dimension - const size_t qkv_channels = (n_head + 2*n_kn_head) * hd; // Q, K, V channels + const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels // validate B,T are not larger than the values used at initialisation // (smaller B,T are okay for inference only) @@ -627,18 +629,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // first layernorm isn't fused rmsnorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_rstd, acts.encoded, params.ln1w, B, T, C, main_stream); - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - floatX* output = (model->recompute < 2) ? acts.ln1 : acts.lnf; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("cpu[%d] = %.8f\n", i, (float) cpu[i]); - } - exit(0); - // ------------------------------------------------------------------------ - for (int l = 0; l < L; l++) { NvtxRange layer_range("Layer", l); @@ -670,21 +660,39 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; floatX* l_residual3 = acts.residual3 + l * B * T * C; floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. + floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication - // now start the block forward pass + // Attention block + // The input l_ln1 now holds the (already layernormed) input #ifdef ENABLE_CUDNN - matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); + printf("cuDNN path TODO\n"); exit(0); + matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); + float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor + attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); #else - floatX* l_att = acts.att + l * B * NH * T * T; - if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) - cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); - } - // these are only needed as scratchpads for the forward pass, but - // need not be stored for backward - matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); - attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); + // unused parts of attention buffer must be zeroed (T-dependent) + floatX* l_att = acts.att + l * B * NH * T * T; + if (T != model->seq_len) { cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); } + // 1) projection to QKV vectors (note k,v may be fewer heads than q) + matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); + // 2) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now + repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd); + + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + floatX* output = qkv_rep_scratch; + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("q[%d] = %.8f\n", i, (float) cpu[i]); + } + exit(0); + // ------------------------------------------------------------------------ + + // 3) apply RoPE to q,k + // 4) attention: att <- softmax(qk^T)v + attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); #endif matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); From 026e4ed323fe87004f3a5af6c95e17894cfc5032 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 24 Sep 2024 23:52:16 +0000 Subject: [PATCH 20/63] add RoPE PyTorch and C reference code --- dev/cbridge/README.md | 13 +++ dev/cbridge/rope.c | 246 ++++++++++++++++++++++++++++++++++++++++++ dev/cbridge/rope.py | 169 +++++++++++++++++++++++++++++ 3 files changed, 428 insertions(+) create mode 100644 dev/cbridge/README.md create mode 100644 dev/cbridge/rope.c create mode 100644 dev/cbridge/rope.py diff --git a/dev/cbridge/README.md b/dev/cbridge/README.md new file mode 100644 index 000000000..00b751487 --- /dev/null +++ b/dev/cbridge/README.md @@ -0,0 +1,13 @@ +# cbridge + +We'll use this directory for the PyTorch -> C bridge. So we have some PyTorch code and we'd like to have the equivalent C implementation (usually that one in turn serves as reference for the CUDA kernels later). + +For starters we have RoPE. E.g. generate the reference with PyTorch and then match it in C: + +```bash +python rope.py +gcc -o rope rope.c -lm +./rope +``` + +The .py file writes a `robe.bin` file with the intermediate tensors. diff --git a/dev/cbridge/rope.c b/dev/cbridge/rope.c new file mode 100644 index 000000000..3677fc6cb --- /dev/null +++ b/dev/cbridge/rope.c @@ -0,0 +1,246 @@ +/* +Our goal here is to load the .bin files generated by rope.py and match +the implementation in C and get the same results as in rope.py. + +Compile and run simply with: + +gcc -o rope rope.c -lm +./rope +*/ + +#include +#include +#include +#include + +// ---------------------------------------------------------------------------- +// a few utils for safety +extern inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { + size_t result = fread(ptr, size, nmemb, stream); + if (result != nmemb) { + if (feof(stream)) { + fprintf(stderr, "Error: Unexpected end of file at %s:%d\n", file, line); + } else if (ferror(stream)) { + fprintf(stderr, "Error: File read error at %s:%d\n", file, line); + } else { + fprintf(stderr, "Error: Partial read at %s:%d. Expected %zu elements, read %zu\n", + file, line, nmemb, result); + } + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Expected elements: %zu\n", nmemb); + fprintf(stderr, " Read elements: %zu\n", result); + exit(EXIT_FAILURE); + } +} +#define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__) + +extern inline void *malloc_check(size_t size, const char *file, int line) { + void *ptr = malloc(size); + if (ptr == NULL) { + fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line); + fprintf(stderr, "Error details:\n"); + fprintf(stderr, " File: %s\n", file); + fprintf(stderr, " Line: %d\n", line); + fprintf(stderr, " Size: %zu bytes\n", size); + exit(EXIT_FAILURE); + } + return ptr; +} + +#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__) + +int compare_arrays(const float *arr1, const float *arr2, size_t size, const char *name, float epsilon) { + for (size_t i = 0; i < size; i++) { + // print 10 elements that are equally spaced out, for qualitative check + if (i % (size / 10) == 0) { + printf("arr1[%zu] = %f, arr2[%zu] = %f\n", i, arr1[i], i, arr2[i]); + } + if (fabsf(arr1[i] - arr2[i]) > epsilon) { + printf("Error: %s[%zu] = %f, expected %f (diff: %f)\n", + name, i, arr1[i], arr2[i], fabsf(arr1[i] - arr2[i])); + return 0; + } + } + printf("OK: %s\n", name); + return 1; +} + +// ---------------------------------------------------------------------------- +// all the functions we need + +void precompute_freqs_cis(float *freqs_cis, int dim, int end, float theta, int use_scaled) { + // same as precompute_freqs_cis_real in rope.py + for (int i = 0; i < dim / 2; i++) { + + // calculate the frequency for the (i, i+1)th dimension + float freq = 1.0f / powf(theta, (float)(2 * i) / dim); + if (use_scaled) { + const int scale_factor = 8; + const int low_freq_factor = 1; + const int high_freq_factor = 4; + const int old_context_len = 8192; // original llama3 length + const float low_freq_wavelen = (float)old_context_len / low_freq_factor; + const float high_freq_wavelen = (float)old_context_len / high_freq_factor; + float wavelen = 2.0f * M_PI / freq; + if (wavelen < high_freq_wavelen) { + // skip; keep freq as is + } else if (wavelen > low_freq_wavelen) { + // scale down by scale_factor + freq /= scale_factor; + } else { + // smooth transition between scaled and unscaled + float smooth = ((float)old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor); + freq = (1.0f - smooth) * freq / scale_factor + smooth * freq; + } + } + + // iterate over all time steps, calculate the angle, and store the cos/sin + for (int t = 0; t < end; t++) { + float angle = (float)t * freq; + freqs_cis[t * dim + 2 * i] = cosf(angle); // real part + freqs_cis[t * dim + 2 * i + 1] = sinf(angle); // imaginary part + } + } +} + +void apply_rotary_emb_forward(float *out, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) { + // same as apply_rotary_emb_real in rope.py + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + for (int h = 0; h < n_head; h++) { + int idx_bth = idx_bt + h * head_dim; + for (int d = 0; d < head_dim / 2; d++) { + // fetch a tuple of activations, which we imagine as a complex number + int idx = idx_bth + 2 * d; + float x_real = inp[idx]; + float x_imag = inp[idx + 1]; + // fetch the angle from freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // apply the rotation + out[idx] = x_real * freqs_cos - x_imag * freqs_sin; + out[idx + 1] = x_real * freqs_sin + x_imag * freqs_cos; + } + } + } + } +} + +void apply_rotary_emb_backward(float *dinp, const float *dout, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) { + // backward pass of the RoPE embedding + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + for (int h = 0; h < n_head; h++) { + int idx_bth = idx_bt + h * head_dim; + for (int d = 0; d < head_dim / 2; d++) { + // fetch the angle from freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // and the input index we'll be updating + int idx = idx_bth + 2 * d; + // backward pass is simple because freqs_cis is just scaling by a constant + dinp[idx] += dout[idx] * freqs_cos + dout[idx + 1] * freqs_sin; + dinp[idx + 1] += -dout[idx] * freqs_sin + dout[idx + 1] * freqs_cos; + } + } + } + } +} + +// ---------------------------------------------------------------------------- + +int main() { + + // load the .bin file + FILE *file = fopen("rope.bin", "rb"); + if (file == NULL) { + printf("Error: Could not open file.\n"); + return 1; + } + // read the header + int int_header[16]; + float float_header[16]; + freadCheck(int_header, sizeof(int), 16, file); + freadCheck(float_header, sizeof(float), 16, file); + // check the magic number + if (int_header[0] != 20240924) { + printf("Error: Invalid magic number.\n"); + fclose(file); + return 1; + } + // extract the hyperparameters + int B = int_header[1]; + int T = int_header[2]; + int n_embd = int_header[3]; + int n_head = int_header[4]; + int use_scaled_rope = int_header[5]; + float rope_theta = float_header[0]; + int head_dim = n_embd / n_head; + // read the inputs + float *inp = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float)); + freadCheck(inp, sizeof(float), B * T * n_head * head_dim, file); + // read the freqs_cis + float *freqs_cis_target = (float *)mallocCheck(T * head_dim * sizeof(float)); + freadCheck(freqs_cis_target, sizeof(float), T * head_dim, file); + // read the output + float *out_target = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float)); + freadCheck(out_target, sizeof(float), B * T * n_head * head_dim, file); + // read the weights for the loss function + float *wei = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float)); + freadCheck(wei, sizeof(float), B * T * n_head * head_dim, file); + // read the input gradients + float *inp_grad_target = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float)); + freadCheck(inp_grad_target, sizeof(float), B * T * n_head * head_dim, file); + // ensure we exactly exhausted the file + long current_position = ftell(file); + // Get the file size + fseek(file, 0, SEEK_END); + long file_size = ftell(file); + // check if we read the whole file + if (current_position != file_size) { + printf("Error: File was not read properly; %ld bytes left unread.\n", file_size - current_position); + fclose(file); + return 1; + } + fclose(file); + + // print the hyperparameters + printf("B: %d, T: %d, n_embd: %d, n_head: %d, use_scaled_rope: %d, rope_theta: %f\n", + B, T, n_embd, n_head, use_scaled_rope, rope_theta); + + // Step 1) Calculate freqs_cis in C and compare with the Python one + float *freqs_cis = (float *)mallocCheck(T * head_dim * sizeof(float)); + precompute_freqs_cis(freqs_cis, head_dim, T, rope_theta, use_scaled_rope); + if (!compare_arrays(freqs_cis, freqs_cis_target, T * head_dim, "freqs_cis", 1e-6f)) { return 1; } + + // Step 2) Apply the RoPE embedding in C and compare with the Python one + float *out = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float)); + apply_rotary_emb_forward(out, inp, freqs_cis, B, T, n_head, head_dim); + if (!compare_arrays(out, out_target, B * T * n_head * head_dim, "out", 1e-6f)) { return 1; } + + // Step 3) Calculate the loss and gradients in C and compare with the Python one + float *dout = wei; // wei is dout because the loss is just a dot product of out and wei + float *dinp = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float)); + apply_rotary_emb_backward(dinp, dout, inp, freqs_cis, B, T, n_head, head_dim); + if (!compare_arrays(dinp, inp_grad_target, B * T * n_head * head_dim, "dinp", 1e-6f)) { return 1; } + + printf("✅ ALL OK\n"); + + // clean up + free(inp); + free(freqs_cis_target); + free(out_target); + free(wei); + free(inp_grad_target); + free(freqs_cis); + free(out); + free(dinp); + + return 0; +} diff --git a/dev/cbridge/rope.py b/dev/cbridge/rope.py new file mode 100644 index 000000000..8a7f90d2f --- /dev/null +++ b/dev/cbridge/rope.py @@ -0,0 +1,169 @@ +""" +Minimal example for developing RoPE. +Basically we cherry-pick / copy paste the critical portions from train_llama3.py +This script then does forward/back and writes everything to file so we can +develop the CPU version, and eventually the GPU kernel as well. +""" + +import math +import torch +import numpy as np + +# ----------------------------------------------------------------------------- +# hyperparameters + +# Llama 3.1 8B config, simplified a tiny bit by removing spurious parameters +class LlamaConfig: + version: str = "3.1" + block_size: int = 8192 + vocab_size: int = 128256 + n_layer: int = 32 + n_head: int = 32 + n_kv_head: int = 8 + n_embd: int = 4096 + rope_theta: float = 500000.0 + use_scaled_rope: bool = True + +config = LlamaConfig() + +# ----------------------------------------------------------------------------- +# freqs_cis + +def apply_scaling(freqs: torch.Tensor): + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + +def precompute_freqs_cis_complex(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False): + """ the complex valued version of this """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + +def precompute_freqs_cis_real(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False): + """ the much simpler real-valued-only version """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + if use_scaled: + freqs = apply_scaling(freqs) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cos = torch.cos(freqs) # real part (end, dim // 2) + freqs_sin = torch.sin(freqs) # imaginary part (end, dim // 2) + freqs_cis = torch.stack([freqs_cos, freqs_sin], dim=-1) # (end, dim // 2, 2) + return freqs_cis + +# ----------------------------------------------------------------------------- +# RoPE forward pass + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + +def apply_rotary_emb_complex(x: torch.Tensor, freqs_cis: torch.Tensor): + x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, x_) + x_out = torch.view_as_real(x_ * freqs_cis).flatten(3) + return x_out.type_as(x) + +def apply_rotary_emb_real(x: torch.Tensor, freqs_cis: torch.Tensor): + xf = x.float() + x_real, x_imag = xf[..., ::2], xf[..., 1::2] + freqs_cos, freqs_sin = freqs_cis[..., 0], freqs_cis[..., 1] + freqs_cos = reshape_for_broadcast(freqs_cos, x_real) + freqs_sin = reshape_for_broadcast(freqs_sin, x_real) + x_out_real = x_real * freqs_cos - x_imag * freqs_sin + x_out_imag = x_real * freqs_sin + x_imag * freqs_cos + x_out = torch.stack([x_out_real, x_out_imag], dim=-1) + return x_out.flatten(3).type_as(x) + +# ----------------------------------------------------------------------------- +# forward, backward, save a little example + +B = 4 +T = 64 +head_dim = config.n_embd // config.n_head + +# input here is (B, T, n_head, head_dim) +# for example in Llama 3.1 8B, this could be B=4, T=64, n_head=32, head_dim=128 +inp = torch.randn(B, T, config.n_head, head_dim, dtype=torch.float32) +inp.requires_grad = True + +# compare the two versions of RoPE +# 1. out_complex is the original PyTorch version that I think is way too complex +# 2. out_real is the simplified version that I think is correct +# ----------------------------------------------------------------------------- +# (1) +freqs_cis = precompute_freqs_cis_complex( + head_dim, + config.block_size * 2, + config.rope_theta, + config.use_scaled_rope, +) +freqs_cis = freqs_cis[:T] # (T, head_dim // 2) of complex64 +out_complex = apply_rotary_emb_complex(inp, freqs_cis) + +# ----------------------------------------------------------------------------- +# (2) +freqs_cis = precompute_freqs_cis_real( + head_dim, + config.block_size * 2, + config.rope_theta, + config.use_scaled_rope, +) +freqs_cis = freqs_cis[:T] # (T, head_dim // 2, 2) +out_real = apply_rotary_emb_real(inp, freqs_cis) +# ----------------------------------------------------------------------------- +assert torch.allclose(out_complex, out_real, atol=1e-6) +print("RoPE simplified version OK") +out = out_real # ok to use this one +# ----------------------------------------------------------------------------- + +# calculate the loss and gradients +wei = torch.randn_like(out, dtype=torch.float32) +loss = (out * wei).sum() +loss.backward() + +# save to .bin file so we can check correctness in C land +int_header = np.zeros(16, dtype=np.int32) # for ints +float_header = np.zeros(16, dtype=np.float32) # for floats +int_header[0] = 20240924 # magic number +int_header[1] = B +int_header[2] = T +int_header[3] = config.n_embd +int_header[4] = config.n_head +int_header[5] = config.use_scaled_rope +float_header[0] = config.rope_theta + +# write the hyperparameters, inputs, output, and input gradients to file +results_file = "rope.bin" +with open(results_file, "wb") as f: + f.write(int_header.tobytes()) # 16 int32 + f.write(float_header.tobytes()) # 16 float32 + f.write(inp.detach().cpu().numpy().tobytes()) # B * T * n_head * head_dim + f.write(freqs_cis.detach().cpu().numpy().tobytes()) # T * head_dim // 2 * 2 + f.write(out.detach().cpu().numpy().tobytes()) # B * T * n_head * head_dim + f.write(wei.detach().cpu().numpy().tobytes()) # B * T * n_head * head_dim + f.write(inp.grad.detach().cpu().numpy().tobytes()) # B * T * n_head * head_dim +print("Saved results to %s" % results_file) From 2ebf8f6b8a1259f4a6003a5d91d16532991ad5b9 Mon Sep 17 00:00:00 2001 From: Aleksa Gordic Date: Wed, 25 Sep 2024 10:37:53 -0700 Subject: [PATCH 21/63] Add rmsnorm fused kernel --- llmc/layernorm.cuh | 95 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 1387b11ab..41f9ec6d3 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -278,6 +278,77 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, } } +__global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* normed, float* rrms, + const floatX* inp1, const floatX* inp2, + const floatX* weight, + int N, int C) { + assert(blockDim.x == WARP_SIZE); + + // load weights and biases into shared memory + // do this before we allow any threads to exit! + extern __shared__ char* params[]; + // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so + // let's keep everything as x128 + x128* s_weight = reinterpret_cast(params); + x128* s_res = reinterpret_cast(params) + ((1 + threadIdx.y) * C / x128::size); + + int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; + for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { + s_weight[i/x128::size] = load128(weight + i); + } + __syncthreads(); + + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if(idx > N) return; + + // adjust pointers to current token + residual += C * idx; + normed += C * idx; + inp1 += C * idx; + inp2 += C * idx; + + const float eps = 1e-5f; + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in1 = load128cs(inp1 + c); + const x128 in2 = load128cs(inp2 + c); + x128 out; + for(int k = 0; k < x128::size; ++k) { + out[k] = (float)in1[k] + (float)in2[k]; + } + store128cs(residual + c, out); + s_res[c / x128::size] = out; + } + + float v = 0.f; + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 res = s_res[c / x128::size]; + for(int k = 0; k < x128::size; ++k) { + v += (float)res[k] * (float)res[k]; + } + } + + v = warpReduceSum(v) / C; + float s = rsqrtf(v + eps); + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 res = s_res[c / x128::size]; + const x128 w = s_weight[c / x128::size]; + x128 out; + for(int k = 0; k < x128::size; ++k) { + float n = s * (float)res[k]; // normalized output + float o = n * (float)w[k]; // scale + out[k] = o; + } + + store128cs(normed + c, out); + } + // cache the rrms for the backward pass later + if(threadIdx.x == 0) { + rrms[idx] = s; + } +} + __global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; @@ -549,6 +620,30 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa cudaCheck(cudaGetLastError()); } +void fused_residual_rmsnorm_forward5(floatX* residual, floatX* normed, float* rrms, + const floatX* inp1, const floatX* inp2, + const floatX* weight, + int N, int C, cudaStream_t stream) { + const int block_size = 256; + int block_y = block_size / WARP_SIZE; + const int grid_size = CEIL_DIV(N, block_y); + size_t smem = (1 + block_y) * C * sizeof(floatX); + + // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute + // this may fail, in which case we fall back to the smem free implementation. + cudaCheck(cudaGetLastError()); + auto status = cudaFuncSetAttribute(fused_residual_rmsnorm_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cudaCheck(cudaGetLastError()); + if(status == cudaSuccess) { + fused_residual_rmsnorm_forward_kernel5<<>>(residual, normed, + rrms, inp1, inp2, + weight, N, C); + } else { + assert(false); + } + cudaCheck(cudaGetLastError()); +} + void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, int B, int T, int C, cudaStream_t stream) { From 52c7254267bc7eda10998f043989fa908e0d543d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 25 Sep 2024 19:00:27 +0000 Subject: [PATCH 22/63] add the finished RoPE forward pass --- dev/cuda/Makefile | 3 +- dev/cuda/repkv.cu | 1 + dev/cuda/rope.cu | 192 ++++++++++++++++++++++++++++++++++++++++++++++ llmc/rope.cuh | 83 ++++++++++++++++++++ train_llama3.cu | 51 ++++++++---- train_llama3.py | 20 +++-- 6 files changed, 326 insertions(+), 24 deletions(-) create mode 100644 dev/cuda/rope.cu create mode 100644 llmc/rope.cuh diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 5e2fada82..9940b1669 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -30,7 +30,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux- $(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ # Build all targets -TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute repkv +TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute repkv rope all: $(TARGETS) all_ptx: $(TARGETS:%=%.ptx) @@ -67,6 +67,7 @@ global_norm: global_norm.cu permute: permute.cu repkv: repkv.cu +rope: rope.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu diff --git a/dev/cuda/repkv.cu b/dev/cuda/repkv.cu index aa816294c..264faaf56 100644 --- a/dev/cuda/repkv.cu +++ b/dev/cuda/repkv.cu @@ -28,6 +28,7 @@ void repkv_forward_cpu(float* out, const float* inp, int B, int T, int C, int hd, int qh, int kh, int vh) { // inp is (B, T, C) + // out is (B, T, 3, NH, HD) // hd = head dimension // qh, kh, vh = number of query, key, value heads assert(C == hd * (qh + kh + vh)); diff --git a/dev/cuda/rope.cu b/dev/cuda/rope.cu new file mode 100644 index 000000000..4e45a4711 --- /dev/null +++ b/dev/cuda/rope.cu @@ -0,0 +1,192 @@ +/* +CUDA kernels for RoPE. + +Compile and run as: +make rope +./rope + +The fastest block size is 128 on H100. +*/ + +#include +#include +#include +#include +#include "common.h" + +void precompute_freqs_cis(float *freqs_cis, int dim, int end, float theta, int use_scaled) { + // same as precompute_freqs_cis_real in rope.py + for (int i = 0; i < dim / 2; i++) { + + // calculate the frequency for the (i, i+1)th dimension + float freq = 1.0f / powf(theta, (float)(2 * i) / dim); + if (use_scaled) { + const int scale_factor = 8; + const int low_freq_factor = 1; + const int high_freq_factor = 4; + const int old_context_len = 8192; // original llama3 length + const float low_freq_wavelen = (float)old_context_len / low_freq_factor; + const float high_freq_wavelen = (float)old_context_len / high_freq_factor; + float wavelen = 2.0f * M_PI / freq; + if (wavelen < high_freq_wavelen) { + // skip; keep freq as is + } else if (wavelen > low_freq_wavelen) { + // scale down by scale_factor + freq /= scale_factor; + } else { + // smooth transition between scaled and unscaled + float smooth = ((float)old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor); + freq = (1.0f - smooth) * freq / scale_factor + smooth * freq; + } + } + + // iterate over all time steps, calculate the angle, and store the cos/sin + for (int t = 0; t < end; t++) { + float angle = (float)t * freq; + freqs_cis[t * dim + 2 * i] = cosf(angle); // real part + freqs_cis[t * dim + 2 * i + 1] = sinf(angle); // imaginary part + } + } +} + +void apply_rotary_emb_forward(float *out, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) { + // same as apply_rotary_emb_real in rope.py + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + for (int h = 0; h < n_head; h++) { + int idx_bth = idx_bt + h * head_dim; + for (int d = 0; d < head_dim / 2; d++) { + // fetch a tuple of activations, which we imagine as a complex number + int idx = idx_bth + 2 * d; + float x_real = inp[idx]; + float x_imag = inp[idx + 1]; + // fetch the angle from freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // apply the rotation + out[idx] = x_real * freqs_cos - x_imag * freqs_sin; + out[idx + 1] = x_real * freqs_sin + x_imag * freqs_cos; + } + } + } + } +} + +// kernel +__global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_dim_half = head_dim / 2; + if (idx >= B * T * n_head * head_dim_half) return; + // decode the individual indices + int b = idx / (T * n_head * head_dim_half); + int t = (idx / (n_head * head_dim_half)) % T; + int h = (idx / head_dim_half) % n_head; + int d = idx % head_dim_half; + // calculate the index in the input + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + int idx_bth = idx_bt + h * head_dim; + int idxi = idx_bth + 2 * d; // index in the input + // fetch the input + float x_real = inp[idxi]; + float x_imag = inp[idxi + 1]; + // fetch the freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // apply the rotation + out[idxi] = x_real * freqs_cos - x_imag * freqs_sin; + out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos; +} + +// launchers +void rope_forward1(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, int block_size) { + // let's launch one thread per element of the output (but divide two!) because the work is in "tuples" + int total_threads = B * T * n_head * head_dim / 2; + int num_blocks = ceil_div(total_threads, block_size); + rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); + cudaCheck(cudaGetLastError()); +} + +void rope_forward(int kernel_num, floatX *out, const floatX *inp, const floatX *freqs_cis, + int B, int T, int n_head, int head_dim, + int block_size) { + switch (kernel_num) { + case 1: + rope_forward1(out, inp, freqs_cis, B, T, n_head, head_dim, block_size); + break; + default: + printf("Invalid kernel number\n"); + exit(1); + } +} + +// tester +int main(int argc, char **argv) { + srand(0); + + int B = 8; + int T = 1024; + int n_head = 32; + int head_dim = 128; + + int deviceIdx = 0; + cudaCheck(cudaSetDevice(deviceIdx)); + + // do the CPU reference calculation + float *inp = make_random_float(B * T * n_head * head_dim); + float *freqs_cis = (float *)malloc(T * head_dim * sizeof(float)); + precompute_freqs_cis(freqs_cis, head_dim, T, 10000, 1); + float *out = (float *)malloc(B * T * n_head * head_dim * sizeof(float)); + apply_rotary_emb_forward(out, inp, freqs_cis, B, T, n_head, head_dim); + + // allocate GPU memory + float *d_inp; + float *d_freqs_cis; + float *d_out; + cudaCheck(cudaMalloc(&d_inp, B * T * n_head * head_dim * sizeof(float))); + cudaCheck(cudaMalloc(&d_freqs_cis, T * head_dim * sizeof(float))); + cudaCheck(cudaMalloc(&d_out, B * T * n_head * head_dim * sizeof(float))); + + // copy data to GPU + cudaCheck(cudaMemcpy(d_inp, inp, B * T * n_head * head_dim * sizeof(float), cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(d_freqs_cis, freqs_cis, T * head_dim * sizeof(float), cudaMemcpyHostToDevice)); + + // read kernel_num from command line + int kernel_num = 1; + if (argc > 1) { + kernel_num = atoi(argv[1]); + } + printf("Using kernel %d\n", kernel_num); + + // check the correctness of the kernel at all block sizes + int block_sizes[] = {32, 64, 128, 256, 512, 1024}; + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + printf("Checking block size %d.\n", block_size); + rope_forward(kernel_num, d_out, d_inp, d_freqs_cis, B, T, n_head, head_dim, block_size); + validate_result(d_out, out, "out", B * T * n_head * head_dim, 1e-5f); + } + printf("All results match. Starting benchmarks.\n\n"); + + // now benchmark + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + int repeat_times = 1000; + float elapsed_time = benchmark_kernel(repeat_times, rope_forward, kernel_num, + d_out, d_inp, d_freqs_cis, B, T, n_head, head_dim, block_size); + printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); + } + + // free memory + free(inp); + free(freqs_cis); + free(out); + cudaCheck(cudaFree(d_inp)); + cudaCheck(cudaFree(d_freqs_cis)); + cudaCheck(cudaFree(d_out)); + return 0; +} + + diff --git a/llmc/rope.cuh b/llmc/rope.cuh new file mode 100644 index 000000000..1d2dce5fa --- /dev/null +++ b/llmc/rope.cuh @@ -0,0 +1,83 @@ +/* +Implements the RoPE rotation for the attention mechanism. + +See dev/cuda/rope.cu for correctness and performance reference +block_size 128 seems fastest on H100 +*/ + +#include "cuda_common.h" + +void precompute_freqs_cis(floatX *freqs_cis, int dim, int end, float theta, int use_scaled) { + // helper function that (on the CPU!) precomputes the freqs_cis for the RoPE rotation + // same as precompute_freqs_cis_real in rope.py + for (int i = 0; i < dim / 2; i++) { + + // calculate the frequency for the (i, i+1)th dimension + float freq = 1.0f / powf(theta, (float)(2 * i) / dim); + if (use_scaled) { + const int scale_factor = 8; + const int low_freq_factor = 1; + const int high_freq_factor = 4; + const int old_context_len = 8192; // original llama3 length + const float low_freq_wavelen = (float)old_context_len / low_freq_factor; + const float high_freq_wavelen = (float)old_context_len / high_freq_factor; + float wavelen = 2.0f * M_PI / freq; + if (wavelen < high_freq_wavelen) { + // skip; keep freq as is + } else if (wavelen > low_freq_wavelen) { + // scale down by scale_factor + freq /= scale_factor; + } else { + // smooth transition between scaled and unscaled + float smooth = ((float)old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor); + freq = (1.0f - smooth) * freq / scale_factor + smooth * freq; + } + } + + // iterate over all time steps, calculate the angle, and store the cos/sin + for (int t = 0; t < end; t++) { + float angle = (float)t * freq; + freqs_cis[t * dim + 2 * i] = cosf(angle); // real part + freqs_cis[t * dim + 2 * i + 1] = sinf(angle); // imaginary part + } + } +} + +__global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_dim_half = head_dim / 2; + if (idx >= B * T * 3 * n_head * head_dim_half) return; + // decode the qkv index early so we can early exit if it's a value index + int qkv = (idx / (n_head * head_dim_half)) % 3; + if (qkv == 2) return; // no-op for v + // decode the individual indices and get the input index + int b = idx / (T * 3 * n_head * head_dim_half); + int t = (idx / (3 * n_head * head_dim_half)) % T; + int h = (idx / head_dim_half) % n_head; + int d = idx % head_dim_half; + int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); + int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; + int idxi = idx_bth + 2 * d; // index in the input + // fetch the input + float x_real = inp[idxi]; + float x_imag = inp[idxi + 1]; + // fetch the freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // apply the rotation + out[idxi] = x_real * freqs_cos - x_imag * freqs_sin; + out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos; +} + +void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + // the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v + // we are going to launch exactly one thread per element of the output, + // except divide by two because the work is in "tuples" + // so this single kernel launch will do RoPE for both q and k, and the threads for v will be a no-op + const int block_size = 128; + int total_threads = B * T * 3 * n_head * head_dim / 2; + int num_blocks = CEIL_DIV(total_threads, block_size); + rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index a83561941..c4d8cb691 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -65,6 +65,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/global_norm.cuh" // defines: repkv_forward #include "llmc/repkv.cuh" +// defines: precompute_freqs_cis, rope_forward +#include "llmc/rope.cuh" // ----------- Multi-GPU support ----------- // defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo // defines: printf0, multi_gpu_config @@ -344,6 +346,7 @@ typedef struct { // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? int* workload_indices; // encoder_backward, B*T*num_c_groups (int) int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case + floatX* freqs_cis; // (T, hd) for RoPE } GPT2; void gpt2_init_common(GPT2 *model) { @@ -373,6 +376,7 @@ void gpt2_init_common(GPT2 *model) { model->init_state = true; model->recompute = 1; // good default: recompute gelu but not layernorm model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main()) + model->freqs_cis = NULL; } void gpt2_allocate_weights(GPT2 *model) { @@ -423,6 +427,15 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups); model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups); + // precompute freqs_cis for RoPE + int hd = model->config.channels / model->config.num_heads; + printf("calculating and allocating %zu KiB for freqs_cis\n", (T * hd * sizeof(floatX)) >> 10); + floatX* freqs_cis_cpu = (floatX*)mallocCheck(T * hd * sizeof(floatX)); + precompute_freqs_cis(freqs_cis_cpu, hd, T, model->config.rope_theta, model->config.use_scaled_rope); + cudaCheck(cudaMalloc((void**)&model->freqs_cis, T * hd * sizeof(floatX))); + cudaCheck(cudaMemcpy(model->freqs_cis, freqs_cis_cpu, T * hd * sizeof(floatX), cudaMemcpyHostToDevice)); + free(freqs_cis_cpu); + // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device // and returns a status code of 1 if it had to fall back, in that case we want to print warning. int memory_status = 0; @@ -677,25 +690,33 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); // 2) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd); - - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - floatX* output = qkv_rep_scratch; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("q[%d] = %.8f\n", i, (float) cpu[i]); - } - exit(0); - // ------------------------------------------------------------------------ - - // 3) apply RoPE to q,k + // 3) apply RoPE to q,k in place + rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd); // 4) attention: att <- softmax(qk^T)v - attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); + attention_forward(l_atty, l_qkvr, l_att, qkv_rep_scratch, B, T, C, NH, main_stream); #endif matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); + + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + floatX* output = scratch; + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("q[%d] = %.8f\n", i, (float) cpu[i]); + } + // write to .bin file + // move output to cpu + floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); + FILE* f = fopen("out.bin", "wb"); + fwrite(cpu_output, sizeof(floatX), B*T*C, f); + fclose(f); + exit(0); + // ------------------------------------------------------------------------ + fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream); matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); diff --git a/train_llama3.py b/train_llama3.py index 08a4f7fb0..30945df63 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -166,14 +166,6 @@ def __init__(self, config): def forward(self, x, freqs_cis=None, start_pos=None, mask=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - # --------------------------------------------------------------------- - # DEBUGGING: print first 32 elements of x - for i in range(32): - print("acts[{}]: {:.8f}".format(i, x.view(-1)[i].item())) - breakpoint() - # --------------------------------------------------------------------- - # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) @@ -206,6 +198,18 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.c_proj(y) + + # --------------------------------------------------------------------- + # DEBUGGING: print first 32 elements of x + x = y.contiguous() + for i in range(32): + print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) + # write to .bin file + with open("ref.bin", "wb") as f: + f.write(x.view(-1).cpu().detach().numpy().tobytes()) + breakpoint() + # --------------------------------------------------------------------- + return y class MLP(nn.Module): From bb3c92da5550b62e33d224ccdc6dc0afaf69b217 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 25 Sep 2024 19:12:18 +0000 Subject: [PATCH 23/63] integrate the fused rmsnorm forward --- train_llama3.cu | 8 +++----- train_llama3.py | 23 +++++++++++------------ 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index c4d8cb691..2afd3593d 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -46,7 +46,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. // defines: encoder_forward, encoder_backward #include "llmc/encoder.cuh" // defines: layernorm_forward, residual_forward, fused_residual_forward5, layernorm_backward -// defines: rmsnorm_forward +// defines: rmsnorm_forward, fused_residual_rmsnorm_forward5 #include "llmc/layernorm.cuh" // defines: matmul_cublaslt, matmul_forward, matmul_backward, gelu_forward, gelu_backward_inplace #include "llmc/matmul.cuh" @@ -653,7 +653,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_attprojw = params.attprojw + l * C * C; floatX* l_attprojb = params.attprojb + l * C; floatX* l_ln2w = params.ln2w + l * C; - floatX* l_ln2b = params.ln2b + l * C; floatX* l_fcw = params.fcw + l * 4*C * C; floatX* l_fcb = params.fcb + l * 4*C; floatX* l_fcprojw = params.fcprojw + l * C * 4*C; @@ -665,7 +664,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; float* l_ln2_rstd = acts.ln2_rstd + l * B * T; floatX* l_fch = acts.fch + l * B * T * 4*C; // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward @@ -697,11 +695,12 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { #endif matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); + fused_residual_rmsnorm_forward5(l_residual2, l_ln2, l_ln2_rstd, residual, scratch, l_ln2w, B*T, C, main_stream); // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here // transfer the first 32 elements to CPU and print them - floatX* output = scratch; + floatX* output = l_ln2; floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < 32; i++) { @@ -717,7 +716,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { exit(0); // ------------------------------------------------------------------------ - fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream); matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); // OK, fusion across blocks. diff --git a/train_llama3.py b/train_llama3.py index 30945df63..9e15e0c1c 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -198,18 +198,6 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.c_proj(y) - - # --------------------------------------------------------------------- - # DEBUGGING: print first 32 elements of x - x = y.contiguous() - for i in range(32): - print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) - # write to .bin file - with open("ref.bin", "wb") as f: - f.write(x.view(-1).cpu().detach().numpy().tobytes()) - breakpoint() - # --------------------------------------------------------------------- - return y class MLP(nn.Module): @@ -228,6 +216,17 @@ def __init__(self, config): def forward(self, x): # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2 + + # --------------------------------------------------------------------- + # DEBUGGING: print first 32 elements of x + for i in range(32): + print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) + # write to .bin file + with open("ref.bin", "wb") as f: + f.write(x.view(-1).cpu().detach().numpy().tobytes()) + breakpoint() + # --------------------------------------------------------------------- + x1 = self.c_fc(x) x2 = self.c_fc2(x) x2 = F.silu(x2) From 1826752ae14067ec3574168e22e53eb0936205a8 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 25 Sep 2024 22:16:45 +0000 Subject: [PATCH 24/63] add swigul yaygit add -u! --- llmc/repkv.cuh | 4 ++-- llmc/rope.cuh | 4 ++-- train_llama3.cu | 40 ++++++++++++++++++++++++++++------------ train_llama3.py | 10 +++++----- 4 files changed, 37 insertions(+), 21 deletions(-) diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh index 055f15ea7..666ad8c44 100644 --- a/llmc/repkv.cuh +++ b/llmc/repkv.cuh @@ -48,14 +48,14 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv, replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); } -void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD) { +void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD, cudaStream_t stream) { // NH = number of query heads, NH_KV = number of key and value heads, HD = head dimension const int block_size = 128; int total_threads = B * T * (3 * NH) * HD; // one thread per output element int num_blocks = CEIL_DIV(total_threads, block_size); int replicate_factor = NH / NH_KV; if (replicate_factor > 1) { - repkv_forward_kernel1<<>>(out, inp, B, T, NH, replicate_factor, HD); + repkv_forward_kernel1<<>>(out, inp, B, T, NH, replicate_factor, HD); } else { cudaMemcpy(out, inp, total_threads * sizeof(floatX), cudaMemcpyDeviceToDevice); } diff --git a/llmc/rope.cuh b/llmc/rope.cuh index 1d2dce5fa..ca5fc56f9 100644 --- a/llmc/rope.cuh +++ b/llmc/rope.cuh @@ -70,7 +70,7 @@ __global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const float out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos; } -void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { +void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { // the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v // we are going to launch exactly one thread per element of the output, // except divide by two because the work is in "tuples" @@ -78,6 +78,6 @@ void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B const int block_size = 128; int total_threads = B * T * 3 * n_head * head_dim / 2; int num_blocks = CEIL_DIV(total_threads, block_size); - rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); + rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); cudaCheck(cudaGetLastError()); } diff --git a/train_llama3.cu b/train_llama3.cu index 2afd3593d..506512b32 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -67,6 +67,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/repkv.cuh" // defines: precompute_freqs_cis, rope_forward #include "llmc/rope.cuh" +// defines: swiglu_forward +#include "llmc/swiglu.cuh" // ----------- Multi-GPU support ----------- // defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo // defines: printf0, multi_gpu_config @@ -252,6 +254,13 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor const size_t n_kv_head = config.num_kv_heads; // num key and value heads const size_t hd = C / n_head; // the size of each head const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels + // SwiGLU-related calculation to determine the number of channels here + size_t hidden_dim = 4 * C; + hidden_dim = (2 * hidden_dim) / 3; + hidden_dim = config.ffn_dim_multiplier * hidden_dim; + hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) / config.multiple_of); + size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated + size_t ffn_channels_post_gelu = hidden_dim; // swiglu will halve the channels tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass @@ -270,9 +279,9 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0); tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T); tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T); - tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C); + tensors[10] = TENSOR_SPEC(data->fch, L * B * T * ffn_channels); // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer - tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C); + tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * ffn_channels_post_gelu : B * T * ffn_channels_post_gelu); tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C); tensors[13] = TENSOR_SPEC(data->lnf, B * T * C); tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T); @@ -621,6 +630,12 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { const size_t n_kv_head = model->config.num_kv_heads; const size_t hd = C / n_head; // head dimension const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels + size_t hidden_dim = 4 * C; + hidden_dim = (2 * hidden_dim) / 3; + hidden_dim = model->config.ffn_dim_multiplier * hidden_dim; + hidden_dim = model->config.multiple_of * ((hidden_dim + model->config.multiple_of - 1) / model->config.multiple_of); + size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated + size_t ffn_channels_post_gelu = hidden_dim; // swiglu halves the channels // validate B,T are not larger than the values used at initialisation // (smaller B,T are okay for inference only) @@ -653,9 +668,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_attprojw = params.attprojw + l * C * C; floatX* l_attprojb = params.attprojb + l * C; floatX* l_ln2w = params.ln2w + l * C; - floatX* l_fcw = params.fcw + l * 4*C * C; - floatX* l_fcb = params.fcb + l * 4*C; - floatX* l_fcprojw = params.fcprojw + l * C * 4*C; + floatX* l_fcw = params.fcw + l * ffn_channels * C; + floatX* l_fcb = params.fcb + l * ffn_channels; + floatX* l_fcprojw = params.fcprojw + l * C * ffn_channels_post_gelu; floatX* l_fcprojb = params.fcprojb + l * C; // get the pointers of the activations for this layer @@ -665,10 +680,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - floatX* l_fch = acts.fch + l * B * T * 4*C; + floatX* l_fch = acts.fch + l * B * T * ffn_channels; // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size - floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; + floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; floatX* l_residual3 = acts.residual3 + l * B * T * C; floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication @@ -687,20 +702,23 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // 1) projection to QKV vectors (note k,v may be fewer heads than q) matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); // 2) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now - repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd); + repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd, main_stream); // 3) apply RoPE to q,k in place - rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd); + rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd, main_stream); // 4) attention: att <- softmax(qk^T)v attention_forward(l_atty, l_qkvr, l_att, qkv_rep_scratch, B, T, C, NH, main_stream); #endif matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); fused_residual_rmsnorm_forward5(l_residual2, l_ln2, l_ln2_rstd, residual, scratch, l_ln2w, B*T, C, main_stream); + matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, ffn_channels, main_stream); + swiglu_forward(l_fch_gelu, l_fch, B, T, ffn_channels_post_gelu, main_stream); + matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, ffn_channels_post_gelu, C, main_stream); // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here // transfer the first 32 elements to CPU and print them - floatX* output = l_ln2; + floatX* output = scratch; floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < 32; i++) { @@ -716,8 +734,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { exit(0); // ------------------------------------------------------------------------ - matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); - matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); // OK, fusion across blocks. if(l+1 != L) { floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf; diff --git a/train_llama3.py b/train_llama3.py index 9e15e0c1c..87a609008 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -216,6 +216,11 @@ def __init__(self, config): def forward(self, x): # SwiGLU self.c_proj(F.silu(self.c_fc2(x)) * self.c_fc(x)) <-- 3. difference compared to GPT-2 + x1 = self.c_fc(x) + x2 = self.c_fc2(x) + x2 = F.silu(x2) + x = x1 * x2 + x = self.c_proj(x) # --------------------------------------------------------------------- # DEBUGGING: print first 32 elements of x @@ -227,11 +232,6 @@ def forward(self, x): breakpoint() # --------------------------------------------------------------------- - x1 = self.c_fc(x) - x2 = self.c_fc2(x) - x2 = F.silu(x2) - x = x1 * x2 - x = self.c_proj(x) return x class Block(nn.Module): From 0731b39a7369ca4560f08722a56ae7c228021f85 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Wed, 25 Sep 2024 22:31:31 +0000 Subject: [PATCH 25/63] forward pass matchesgit add train_llama3.cu train_llama3.py ! losses are the same. now comes the backward pass --- train_llama3.cu | 52 +++++++++++++++++++++++-------------------------- train_llama3.py | 22 ++++++++++----------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index 506512b32..e515bb185 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -110,7 +110,7 @@ typedef struct { constexpr const int NUM_PARAMETER_TENSORS = 16; typedef struct { floatX* wte; // (V, C) - floatX* wpe; // (maxT, C) + floatX* wpe; // (V, C) floatX* ln1w; // (L, C) floatX* ln1b; // (L, C) floatX* qkvw; // (L, 3*C, C) @@ -715,42 +715,18 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { swiglu_forward(l_fch_gelu, l_fch, B, T, ffn_channels_post_gelu, main_stream); matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, ffn_channels_post_gelu, C, main_stream); - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - floatX* output = scratch; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("q[%d] = %.8f\n", i, (float) cpu[i]); - } - // write to .bin file - // move output to cpu - floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); - FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), B*T*C, f); - fclose(f); - exit(0); - // ------------------------------------------------------------------------ - // OK, fusion across blocks. if(l+1 != L) { floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T; float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T; const floatX* l_ln1w = params.ln1w + (l + 1) * C; - const floatX* l_ln1b = params.ln1b + (l + 1) * C; - fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b, - B * T, C, main_stream); + fused_residual_rmsnorm_forward5(l_residual3, l_ln1, l_ln1_rstd, l_residual2, scratch, l_ln1w, B * T, C, main_stream); } else { - fused_residual_forward5(l_residual3, acts.lnf, acts.lnf_mean, acts.lnf_rstd, l_residual2, scratch, - params.lnfw, params.lnfb, - B * T, C, main_stream); + fused_residual_rmsnorm_forward5(l_residual3, acts.lnf, acts.lnf_rstd, l_residual2, scratch, params.lnfw, B * T, C, main_stream); } } - matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + matmul_forward_cublaslt(acts.output, acts.lnf, params.wpe, NULL, B, T, C, Vp, main_stream); cudaCheck(cudaDeviceSynchronize()); } @@ -821,6 +797,26 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tokenCheck(targets, B*T, V); fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream); + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + float* output = acts.losses; + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("q[%d] = %.8f\n", i, (float) cpu[i]); + } + // write to .bin file + // move output to cpu + floatX* cpu_output = (floatX*)mallocCheck(B*T * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, B*T * sizeof(floatX), cudaMemcpyDeviceToHost)); + FILE* f = fopen("out.bin", "wb"); + fwrite(cpu_output, sizeof(floatX), B*T, f); + fclose(f); + exit(0); + // ------------------------------------------------------------------------ + + // ------------------------------------------------------------------------ // backward pass: go in the reverse order of the forward pass, and call backward() functions // reset residual stream gradients (put here to work with gradient accumulation) diff --git a/train_llama3.py b/train_llama3.py index 87a609008..d2b6c0c38 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -221,17 +221,6 @@ def forward(self, x): x2 = F.silu(x2) x = x1 * x2 x = self.c_proj(x) - - # --------------------------------------------------------------------- - # DEBUGGING: print first 32 elements of x - for i in range(32): - print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) - # write to .bin file - with open("ref.bin", "wb") as f: - f.write(x.view(-1).cpu().detach().numpy().tobytes()) - breakpoint() - # --------------------------------------------------------------------- - return x class Block(nn.Module): @@ -323,6 +312,17 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim loss = None + # --------------------------------------------------------------------- + # DEBUGGING: print first 32 elements of x + x = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction='none') + for i in range(32): + print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) + # write to .bin file + with open("ref.bin", "wb") as f: + f.write(x.view(-1).cpu().detach().numpy().tobytes()) + breakpoint() + # --------------------------------------------------------------------- + # there are performance reasons why not returning logits is prudent, if not needed if not return_logits: logits = None From d1f2f64541fa7754443741ec2f03db706d5f58e2 Mon Sep 17 00:00:00 2001 From: Insop Song Date: Wed, 25 Sep 2024 17:40:38 -0700 Subject: [PATCH 26/63] Updated repkv_backward cuda kernel - kernel 1 is tested - build ``` make repkv_backward /usr/local/cuda/bin/nvcc -O3 --use_fast_math --generate-code arch=compute_80,code=[compute_80,sm_80] -lcublas -lcublasLt -std=c++17 repkv_backward.cu -o repkv_backward ``` - test run on A30 ``` Using kernel 1 Checking block size 32. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 64. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 128. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 256. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 512. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 1024. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 All results match. Starting benchmarks. block_size 32 time 3.2461 ms block_size 64 time 1.7509 ms block_size 128 time 1.7374 ms block_size 256 time 1.7441 ms block_size 512 time 1.8092 ms block_size 1024 time 2.0443 ms ``` --- dev/cuda/Makefile | 2 +- dev/cuda/repkv_backward.cu | 61 ++++++++++++++++---------------------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index 90ca0c993..f21973527 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -19,7 +19,7 @@ endif ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY= CFLAGS = -O3 --use_fast_math else - CFLAGS = -DDEBUG -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] + CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)] endif NVCCFLAGS = -lcublas -lcublasLt -std=c++17 diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index e43b121e9..93f2720fa 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -1,7 +1,4 @@ /* - -TODO: update the description - Layer that takes a QKV tensor of shape (B, T, C) and replicates the K,V some number of times. For example, if B=4, T=64, C=6144, and we have that: - head dimension (hd) is 128 channels @@ -76,10 +73,8 @@ __global__ void repkv_backward_kernel1(floatX* dinp, // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD) int idx = blockIdx.x * blockDim.x + threadIdx.x; - // ?? if (idx >= B * N * 3 * NH * HD) { return;} - // ?? - int doutp_idx = idx; // keep backp + int doutp_idx = idx; // keep backup // decode the doutp index int d = idx % HD; @@ -92,32 +87,31 @@ __global__ void repkv_backward_kernel1(floatX* dinp, int b = idx / N; int dinp_idx; - // int nh_total = NH * 3; int nh_total = NH + 2 * (NH / replicate_factor); if (c == 0) { - dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD * nh * HD + d; - dinp[dinp_idx] = __ldca(&doutp[doutp_idx]); + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; + dinp[dinp_idx] = __ldcs(&doutp[doutp_idx]); } else if (c == 1) { - dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; - // float reduced_sum = 0; - // if (doutp_idx % replicate_factor == 0) { - // for (int i = doutp_idx; i < doutp_idx+replicate_factor; i++) - // reduced_sum += __ldcs(&doutp[i]); - // dinp[dinp_idx] = reduced_sum; - // } - // ?? - dinp[dinp_idx] = __ldca(&doutp[doutp_idx]); + if (nh % replicate_factor == 0) { + float reduced_sum = 0; + for (int i = 0; i < replicate_factor; i++) { + reduced_sum += __ldcs(&doutp[doutp_idx+HD*i]); + } + + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + dinp[dinp_idx] = reduced_sum; + } + } else { - dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; - // float reduced_sum = 0; - // if (doutp_idx % replicate_factor == 0) { - // for (int i = doutp_idx; i < doutp_idx + replicate_factor; i++) - // reduced_sum += __ldcs(&doutp[i]); - // dinp[dinp_idx] = reduced_sum; - // } - // ?? - dinp[dinp_idx] = __ldca(&doutp[doutp_idx]); + if (nh % replicate_factor == 0) { + float reduced_sum = 0; + for (int i = 0; i < replicate_factor; i++) { + reduced_sum += __ldcs(&doutp[doutp_idx+HD*i]); + } + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + dinp[dinp_idx] = reduced_sum; + } } } @@ -146,8 +140,8 @@ void repkv_backward(int kernel_num, } } -// TODO: update -void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, char *title) +#ifdef DEBUG +static void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, char *title) { printf("%s -----\n", title); for (int b = 0; b < B; b++) { @@ -183,6 +177,7 @@ void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, ch } printf("\n"); } +#endif // DEBUG // tester int main(int argc, char **argv) { @@ -195,7 +190,6 @@ int main(int argc, char **argv) { int qh = 4; // num query heads int kh = 2; // num key heads int vh = 2; // num value heads - int nrep = qh/kh; #else int B = 8; int T = 1024; @@ -203,7 +197,6 @@ int main(int argc, char **argv) { int qh = 32; // num query heads int kh = 8; // num key heads int vh = 8; // num value heads - int nrep = qh/kh; #endif int deviceIdx = 0; @@ -234,6 +227,7 @@ int main(int argc, char **argv) { printf("Using kernel %d\n", kernel_num); #ifdef DEBUG + int nrep = qh/kh; log_mat(doutp, B, T, Cout, hd, qh, nrep*kh, nrep*vh, "doutp"); #endif // DEBUG @@ -256,17 +250,14 @@ int main(int argc, char **argv) { } printf("All results match. Starting benchmarks.\n\n"); - // TODO: update -#if 0 // now benchmark for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; int repeat_times = 1000; float elapsed_time = benchmark_kernel(repeat_times, repkv_backward, kernel_num, - d_out, d_inp, B, T, qh, kh, hd, block_size); + d_dinp, d_inp, d_doutp, B, T, qh, kh, hd, block_size); printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } -#endif // free memory free(inp); From 31be5e790d249c9a6a7d8c1916c5bed5d6c47d6a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 26 Sep 2024 03:17:58 +0000 Subject: [PATCH 27/63] add rmsnorm backward in dev/cuda, it seems to work surprisingly, and is probably ready to be integrated into llmc. we are still using 2X too much shared memory because I didn't want to change way too many things at the same time. I copy pasted our kernel10 of layernorm backward and made tweaks to it removing the bias and mean cool --- dev/cbridge/rmsnorm.py | 101 ++++++++++ dev/cuda/rmsnorm_backward.cu | 355 +++++++++++++++++++++++++++++++++++ 2 files changed, 456 insertions(+) create mode 100644 dev/cbridge/rmsnorm.py create mode 100644 dev/cuda/rmsnorm_backward.cu diff --git a/dev/cbridge/rmsnorm.py b/dev/cbridge/rmsnorm.py new file mode 100644 index 000000000..16fe5b86f --- /dev/null +++ b/dev/cbridge/rmsnorm.py @@ -0,0 +1,101 @@ +""" +An RMSNorm PyTorch reference implementation. +This script then does forward/back and writes everything to file so we can +develop the CPU version, and eventually the GPU kernel as well. +""" + +import math +import torch +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + +# ----------------------------------------------------------------------------- + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + self.eps + rstd = torch.rsqrt(mean_sq) + norm = x * rstd + return norm + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + +def rmsnorm_backward(x, w, dout, eps): + # recompute the rstd, norm (or we could cache it in the forward pass) + mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + eps # (B, T, 1) + rstd = torch.rsqrt(mean_sq) # (B, T, 1) + norm = x * rstd # (B, T, C) + # gradients for weights + dw = (dout * norm).sum((0, 1)) # (C) + # gradients for input + dnorm = dout * w # (B, T, C) + dx = dnorm - norm * (dnorm * norm).mean(dim=-1, keepdim=True) + dx *= rstd + return dx, dw + +# ----------------------------------------------------------------------------- + +# seed the rng +torch.manual_seed(42) + +B = 4 +T = 64 +C = 256 +eps = 1e-5 + +inp = torch.randn(B, T, C, dtype=torch.float32) +inp.requires_grad = True + +# rmsnorm +m = RMSNorm(C, eps=eps) +out = m(inp) + +# loss can just be a weighted sum, with some fixed weights +wei = torch.randn_like(out, dtype=torch.float32) +loss = (out * wei).sum() +loss.backward() + +# let's now do the backward pass manually +# backprop starts with the output gradient, which is exactly wei because of the loss functions +dx, dw = rmsnorm_backward(inp, m.weight, wei, eps) +# let's assert that the gradients match +assert torch.allclose(dx, inp.grad, atol=1e-6) +assert torch.allclose(dw, m.weight.grad, atol=1e-6) +print("RMSNorm gradients match") +print("first 5 elements of dx comparison:") +print(dx.view(-1)[:5].tolist()) +print(inp.grad.view(-1)[:5].tolist()) +print("first 5 elements of dw comparison:") +print(dw.view(-1)[:5].tolist()) +print(m.weight.grad.view(-1)[:5].tolist()) +print("dx error:", (inp.grad.view(-1) - dx.view(-1)).abs().max().item()) +print("dw error:", (m.weight.grad.view(-1) - dw.view(-1)).abs().max().item()) + +# save to .bin file so we can check correctness in C land +int_header = np.zeros(16, dtype=np.int32) # for ints +float_header = np.zeros(16, dtype=np.float32) # for floats +int_header[0] = 20240925 # magic number +int_header[1] = B +int_header[2] = T +int_header[3] = C +float_header[0] = eps + +# write the hyperparameters, inputs, output, and input gradients to file +results_file = "rmsnorm.bin" +with open(results_file, "wb") as f: + f.write(int_header.tobytes()) # 16 int32 + f.write(float_header.tobytes()) # 16 float32 + f.write(inp.detach().cpu().numpy().tobytes()) # B * T * C + f.write(out.detach().cpu().numpy().tobytes()) # B * T * C + f.write(wei.detach().cpu().numpy().tobytes()) # B * T * C + f.write(inp.grad.detach().cpu().numpy().tobytes()) # B * T * C + f.write(m.weight.grad.detach().cpu().numpy().tobytes()) # C +print("Saved results to %s" % results_file) diff --git a/dev/cuda/rmsnorm_backward.cu b/dev/cuda/rmsnorm_backward.cu new file mode 100644 index 000000000..70868da01 --- /dev/null +++ b/dev/cuda/rmsnorm_backward.cu @@ -0,0 +1,355 @@ +/* +Kernels for RMSNorm backward pass. + +Compile example: +nvcc -O3 --use_fast_math -lcublas -lcublasLt rmsnorm_backward.cu -o rmsnorm_backward + +./rmsnorm_backward 1 +*/ + +#include +#include +#include +#include + +#define ENABLE_BF16 +#include "common.h" + +// ---------------------------------------------------------------------------- +// CPU code reference + +void rmsnorm_forward_cpu(float* out, float* rstd, + const float* inp, const float* weight, + int B, int T, int C) { + float eps = 1e-5f; + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + // seek to the input position inp[b,t,:] + const float* x = inp + b * T * C + t * C; + // calculate the variance (without any bias correction) + float v = 0.0f; + for (int i = 0; i < C; i++) { + float xi = x[i]; + v += xi * xi; + } + v = v/C; + // calculate the rstd (reciprocal standard deviation) + float s = 1.0f / sqrtf(v + eps); + // seek to the output position in out[b,t,:] + float* out_bt = out + b * T * C + t * C; + for (int i = 0; i < C; i++) { + float n = (s * x[i]); // normalize + float o = n * weight[i]; // scale and shift + out_bt[i] = o; // write + } + // cache the rstd for the backward pass later + rstd[b * T + t] = s; + } + } +} + +void rmsnorm_backward_cpu(float* dinp, float* dweight, + const float* dout, const float* inp, const float* weight, const float* rstd, + int B, int T, int C) { + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + const float* dout_bt = dout + b * T * C + t * C; + const float* inp_bt = inp + b * T * C + t * C; + float* dinp_bt = dinp + b * T * C + t * C; + const float rstd_bt = rstd[b * T + t]; + + // first: the reduce operation + float dnorm_norm_mean = 0.0f; + for (int i = 0; i < C; i++) { + float norm_bti = inp_bt[i] * rstd_bt; + float dnorm_i = weight[i] * dout_bt[i]; + dnorm_norm_mean += dnorm_i * norm_bti; + } + dnorm_norm_mean = dnorm_norm_mean / C; + + // now iterate again and accumulate all the gradients + for (int i = 0; i < C; i++) { + float norm_bti = inp_bt[i] * rstd_bt; + float dnorm_i = weight[i] * dout_bt[i]; + // gradient contribution to weight + dweight[i] += norm_bti * dout_bt[i]; + // gradient contribution to input + float dval = 0.0f; + dval += dnorm_i; // term 1 + dval -= norm_bti * dnorm_norm_mean; // term 2 + dval *= rstd_bt; // final scale + dinp_bt[i] += dval; + } + } + } +} + +// ---------------------------------------------------------------------------- +// GPU kernel + +__global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? + rmsnorm_backward_kernel10(floatX* dinp, floatX* dweight, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, + const float* rstd, int B, int T, int C) { + // TODO: this kernel uses too much shared memory due to historical reasons of it coming from layernorm_backward.cu + // this memory use can be reduced by half later + int BLOCK_SIZE = blockDim.x; + int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block + extern __shared__ float shared[]; + + int warpId = threadIdx.x / WARP_SIZE; // warp index within a block + int baseIdx = blockIdx.x * warpsInBlock + warpId; + int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp + int warpsInGrid = gridDim.x * warpsInBlock; + int C_per_iteration = WARP_SIZE * x128::size; + int iterations_C = ceil_div(C, C_per_iteration); // + 2; + + // the first half of shared memory is bias, second is weight + size_t rounded_C = ceil_div(C, (32 * x128::size)) * (32 * x128::size); + float* dweight_shared = shared + rounded_C; + // warp zero doesn't actually write to the _tmp_shared memory locations, so we don't need to reserve memory + // the obvious solution is to change the addressing below to use (threadId.x-32) as offset, but that causes + // register spills, so instead we mess with the base pointer here, which doesn't increase register usage. + float* dweight_tmp_shared = shared + 2 * rounded_C + f128::size * BLOCK_SIZE - 2 * WARP_SIZE * f128::size; + + // init shared memory to zero + for(int i = threadIdx.x * f128::size; i < rounded_C; i += BLOCK_SIZE * f128::size) { + store128(dweight_shared + i, f128::zeros()); + } + __syncthreads(); + + for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) { + const floatX* dout_bt = dout + bt * C; + const floatX* inp_bt = inp +bt * C; + floatX* dinp_bt = dinp + bt * C; + + // first: two reduce operations + float dnorm_mean = 0.0f; + float dnorm_norm_mean = 0.0f; + for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { + x128 dout128_i = load128(dout_bt + i); + x128 inp128_i = load128(inp_bt + i); + x128 weight128_i = load128(weight + i); + for (int k = 0; k < x128::size; k++) { + float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; + dnorm_mean += dnorm_i; + dnorm_norm_mean += dnorm_i * (float)inp128_i[k]; + } + } + + const float rstd_bt = rstd[bt]; + dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C * rstd_bt; + + for (int c = 0; c < iterations_C; c++) { + int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); + + x128 dout128 = x128::zeros(); + x128 inp128 = x128::zeros(); + x128 dinp128 = x128::zeros(); + x128 weight128 = x128::zeros(); + + if(global_index < C) { + dout128 = load128cs(dout_bt + global_index); + inp128 = load128cs(inp_bt + global_index); + dinp128 = load128(dinp_bt + global_index); + weight128 = load128(weight + global_index); + } + + for(int o = 0; o < x128::size / f128::size; ++o) { + f128 dweight_f; + for(int i = 0; i < f128::size; ++i) { + int x = o * f128::size + i; + float dout_i = (float)dout128[x]; + float norm_bti = ((float)inp128[x]) * rstd_bt; + dweight_f[i] = norm_bti * dout_i; + + float dval = 0.0f; + dval += (float) weight128[x] * (float)dout128[x]; // term 1 + dval -= norm_bti * dnorm_norm_mean; // term 2 + dval *= rstd_bt; // final scale + dinp128[x] = (floatX) ((float) dinp128[x] + dval); + } + + if (warpId != 0) { + // this seems to generate a 64-bit store, instead of 128-bit. + // however, forcing 128-bit (e.g., using inline ptx), results in register + // spilling and much worse performance, so we'll keep it like this for now + // but ideally, we could reduce the register pressure a little. + store128(dweight_tmp_shared + threadIdx.x * f128::size, dweight_f); + } + __syncthreads(); + if (warpId == 0) { + for (int j = 1; j < warpsInBlock; j++) { + f128 dweight_tmp = load128(dweight_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE)); + for(int i = 0; i < f128::size; ++i) { + dweight_f[i] += dweight_tmp[i]; + } + } + } + __syncthreads(); + if (warpId == 0) { + f128 dw_old = load128(dweight_shared + global_index + f128::size * o); + for(int i = 0; i < f128::size; ++i) { + dweight_f[i] += dw_old[i]; + } + store128(dweight_shared + global_index + f128::size * o, dweight_f); + } + } + if(global_index < C) { + // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing + store128cg(dinp_bt + global_index, dinp128); + } + } + } + __syncthreads(); + // Each block writes its partial sum to global memory + // The last block to finish becomes responsible for summing up all the partial sums + // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + unsigned int* scratchFlag = (unsigned int*)(scratch); + // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned + scratch += 32; + float* scratch_dweight = scratch + C; + for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) { + // Write to global memory in the same "shared memory banking friendly" order + store128(scratch_dweight + i + 2*C*blockIdx.x, load128(dweight_shared + i)); + } + __syncthreads(); + // that portion of shared memory is no longer used, so we can repurpose it for the scratch flag. + unsigned int *tmp_flag = (unsigned int*)(shared + 2*rounded_C); + if (threadIdx.x == 0) { + *tmp_flag = atomicInc(scratchFlag, gridDim.x); + } + __syncthreads(); + if (*tmp_flag == gridDim.x-1) { + // Reduction of the partial sums by the final block + // todo - there isn't enough parallelism even inside that single SM... + // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! + for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) { + f128 dweight_accum = f128::zeros(); + + for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { + int offset = i + 2*C*read_block_idx; + f128 dweight128 = load128(scratch_dweight + offset); + for(int k = 0; k < f128::size; k++) { + dweight_accum[k] += dweight128[k]; + } + } + store128(dweight_shared + i, dweight_accum); + } + __syncthreads(); + + // convert from float/FP32 to floatX/BF16 for the final write + // this is separate because it cannot use as many warps as the above (f128 vs x128) + // todo - if we split this code into another kernel, we could maybe do it at the same time? + for (int c = warpId; c < iterations_C; c += warpsInBlock) { + int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); + if (global_index >= C) { + break; + } + x128 dweight128 = load128(dweight + global_index); + for(int o = 0; o < x128::size / f128::size; ++o) { + f128 s_dw = load128(dweight_shared + global_index + o * f128::size); + for(int i = 0; i < f128::size; ++i) { + int x = o * f128::size + i; + dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]); + } + } + store128(dweight + global_index, dweight128); + } + } +} + +// ---------------------------------------------------------------------------- +// Kernel launcher + +void rmsnorm_backward(floatX* dinp, floatX* dweight, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, const float* rstd, + int B, int T, int C, cudaStream_t stream) { + const int block_size = 512; + const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3 + // const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount; + const int grid_size = blocks_per_sm * cuda_num_SMs; + size_t rounded_C = ceil_div(C, (32 * x128::size)) * (32 * x128::size); + size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); + + cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 + rmsnorm_backward_kernel10<<>>(dinp, dweight, scratch, dout, inp, weight, rstd, B, T, C); + cudaCheck(cudaGetLastError()); +} + +// ---------------------------------------------------------------------------- + +int main(int argc, char **argv) { + setup_main(); + + int B = 8; + int T = 1024; + int C = 1024; + + // first do the forward pass in CPU + float* out = (float*)malloc(B * T * C * sizeof(float)); + float* rstd = (float*)malloc(B * T * sizeof(float)); + float* inp = make_random_float(B * T * C); + float* weight = make_random_float(C); + rmsnorm_forward_cpu(out, rstd, inp, weight, B, T, C); + + // now do the backward pass, again on CPU + float *dout = make_random_float(B * T * C); + float *dinp = make_zeros_float(B * T * C); + float *dweight = make_zeros_float(C); + rmsnorm_backward_cpu(dinp, dweight, dout, inp, weight, rstd, B, T, C); + + // the above calculations act as the reference + // now let's do the same on the GPU + + // read kernel_num from command line + int kernel_num = 1; + if (argc > 1) { + kernel_num = atoi(argv[1]); + } + printf("Using kernel %d\n", kernel_num); + + // move all the variables we need for backward pass onto the GPU + floatX* d_dinp; + floatX* d_dweight; + floatX* d_dout; + floatX* d_inp; + floatX* d_weight; + float* d_rstd; + float* d_scratch; + cudaCheck(cudaMalloc(&d_dinp, B * T * C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_dweight, C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_dout, B * T * C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX))); + cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(float))); + cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float))); + + // copy over the "inputs" to the backward call + cudaCheck(memcpy_convert(d_dout, dout, B * T * C)); + cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); + cudaCheck(memcpy_convert(d_weight, weight, C)); + cudaCheck(memcpy_convert(d_rstd, rstd, B * T)); + + // launch the kernel + int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024}; + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + // init the "outputs" of the backward call to zeros + cudaCheck(cudaMemset(d_dinp, 0, B * T * C * sizeof(floatX))); + cudaCheck(cudaMemset(d_dweight, 0, C * sizeof(floatX))); + + rmsnorm_backward(d_dinp, d_dweight, d_scratch, d_dout, d_inp, d_weight, d_rstd, B, T, C, 0); + + // check the correctness of the kernel + float error_threshold_dinp = sizeof(floatX) == 4 ? 1e-3f : 1e-1f; // allow larger errors for BF16/FP16 + float error_threshold_dparams = sizeof(floatX) == 4 ? 1e-3f : 5e-1f; // much, much larger... + printf("Checking correctness...\n"); + printf("dinp:\n"); + validate_result(d_dinp, dinp, "dinp", B * T * C, error_threshold_dinp); + printf("dweight:\n"); + validate_result(d_dweight, dweight, "dweight", C, error_threshold_dparams); + + printf("All results match for block_size=%d.\n\n", block_size); + } +} From 102067fbf5d3a389434520b46865d85a73cbca92 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 26 Sep 2024 22:08:19 +0000 Subject: [PATCH 28/63] oops i think i accidentally forgot to include swiglu.cuh --- llmc/swiglu.cuh | 86 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 llmc/swiglu.cuh diff --git a/llmc/swiglu.cuh b/llmc/swiglu.cuh new file mode 100644 index 000000000..a574233d3 --- /dev/null +++ b/llmc/swiglu.cuh @@ -0,0 +1,86 @@ +/* +SwiGLU activation function +Unlike GeLU, SwiGLU is a bit more tricky because there are two separate linear layers. +In PyTorch we have: + +self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=False) +self.c_fc2 = nn.Linear(config.n_embd, hidden_dim, bias=False) + +and then: + +x1 = self.c_fc(x) +x2 = self.c_fc2(x) +x2 = F.silu(x2) +x = x1 * x2 + +But in our implementation to minimize the amount of changes, we have the weights of +the two linear layers concatenated together. So in this non-linearity, we receive +as input the conctatenation of [x1, x2], and our job is just to apply silu and +elementwise multiply. And we have to be careful because the output size is half +the input size! +*/ + +#include +// llmc internal imports +#include "cuda_common.h" +#include "cuda_utils.cuh" + +// ---------------------------------------------------------------------------- +// CUDA kernels + +__global__ void swiglu_forward_kernel1(floatX* out, const floatX* inp, int B, int T, int C) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + floatX* out_ptr = out + idx; + // b,t,c in the output + int b = idx / (T * C); + int t = (idx / C) % T; + int c = idx % C; + + int C2 = C * 2; + const floatX* inp1_ptr = inp + (b * T * C2 + t * C2 + c); + const floatX* inp2_ptr = inp1_ptr + C; + + x128 packed_out; + x128 packed_inp1 = load128cs(inp1_ptr); // fc1 + x128 packed_inp2 = load128cs(inp2_ptr); // fc2 + for(int k = 0; k < packed_inp1.size; ++k) { + float x1 = (float)packed_inp1[k]; + float x2 = (float)packed_inp2[k]; + packed_out[k] = (floatX)((x1 * x2) / (1.0f + expf(-x2))); + } + store128(out_ptr, packed_out); +} + +__global__ void swiglu_forward_kernel2(floatX* out, const floatX* inp, int B, int T, int C) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + // derive the b,t,c from idx + int b = idx / (T * C); + int t = (idx / C) % T; + int c = idx % C; + int C2 = C * 2; + float x1 = (float) inp[b * T * C2 + t * C2 + c]; + float x2 = (float) inp[b * T * C2 + t * C2 + c + C]; + out[idx] = (floatX)((x1 * x2) / (1.0f + expf(-x2))); +} + +// ---------------------------------------------------------------------------- +// kernel launchers + +void swiglu_forward(floatX* out, const floatX* inp, int B, int T, int C, cudaStream_t stream) { + // input is (B, T, 2C), output is (B, T, C) + // we have that inp[b, t, :] = [fc1, fc2] (i.e. they are concatenated in each C-fiber) + NVTX_RANGE_FN(); + const int block_size = 128; + assert((B*T*C) % (block_size * x128::size) == 0); + const int grid_size = CEIL_DIV(B*T*C, block_size * x128::size); + swiglu_forward_kernel1<<>>(out, inp, B, T, C); + cudaCheck(cudaGetLastError()); +} + +void swiglu_forward_naive(floatX* out, const floatX* inp, int B, int T, int C, cudaStream_t stream) { + // same as above but no x128 packing to be SAFE + const int block_size = 128; + const int grid_size = CEIL_DIV(B*T*C, block_size); + swiglu_forward_kernel2<<>>(out, inp, B, T, C); + cudaCheck(cudaGetLastError()); +} From 2c4b3cc8bb4dc5565965f73b439cc642612e7913 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Thu, 26 Sep 2024 22:27:58 +0000 Subject: [PATCH 29/63] integrate our rmsnorm backward and move the other rmsnorm functions into rmsnorm.cuh that is a new file --- llmc/layernorm.cuh | 179 --------------------- llmc/rmsnorm.cuh | 380 +++++++++++++++++++++++++++++++++++++++++++++ train_llama3.cu | 46 +++--- train_llama3.py | 28 ++-- 4 files changed, 419 insertions(+), 214 deletions(-) create mode 100644 llmc/rmsnorm.cuh diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 41f9ec6d3..9777d0658 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -139,66 +139,6 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res } } -__global__ void rmsnorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ rms, - const floatX* __restrict__ inp, const floatX* __restrict__ weight, int N, int C) { - // this kernel is a simplified version of layernorm_forward_kernel6 - assert(blockDim.x == WARP_SIZE); - - // load weights into shared memory - // do this before we allow any threads to exit! - extern __shared__ char* params[]; - // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so - // let's keep everything as x128 - x128* s_weight = reinterpret_cast(params); - x128* s_in = reinterpret_cast(params) + ((1 + threadIdx.y) * C / x128::size); - - int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; - for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { - s_weight[i/x128::size] = load128(weight + i); - } - __syncthreads(); - - int idx = blockIdx.x * blockDim.y + threadIdx.y; - if(idx >= N) { return; } // guard - - // adjust pointers to current token - inp += idx * C; - out += idx * C; - - const float eps = 1e-5f; - float acc = 0.f; - - for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in_data = load128cs(inp + c); - s_in[c / x128::size] = in_data; - for(int k = 0; k < x128::size; ++k) { - float data_k = (float)in_data[k]; - acc += data_k * data_k; - } - } - - acc = warpReduceSum(acc) / C; - float s = rsqrtf(acc + eps); - - for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in_data = s_in[c / x128::size]; - const x128 w = s_weight[c / x128::size]; - x128 out_data; - for(int k = 0; k < x128::size; ++k) { - float n = s * (float)in_data[k]; // normalized output - float o = n * (float)w[k]; // scale - out_data[k] = (floatX)o; - } - - store128cs(out + c, out_data); - } - - // store the rms, no need to cache it - if(threadIdx.x == 0 && rms != nullptr) { - __stcs(rms + idx, s); - } -} - __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd, const floatX* inp1, const floatX* inp2, const floatX* weight, const floatX* bias, @@ -278,77 +218,6 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, } } -__global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* normed, float* rrms, - const floatX* inp1, const floatX* inp2, - const floatX* weight, - int N, int C) { - assert(blockDim.x == WARP_SIZE); - - // load weights and biases into shared memory - // do this before we allow any threads to exit! - extern __shared__ char* params[]; - // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so - // let's keep everything as x128 - x128* s_weight = reinterpret_cast(params); - x128* s_res = reinterpret_cast(params) + ((1 + threadIdx.y) * C / x128::size); - - int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; - for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { - s_weight[i/x128::size] = load128(weight + i); - } - __syncthreads(); - - int idx = blockIdx.x * blockDim.y + threadIdx.y; - if(idx > N) return; - - // adjust pointers to current token - residual += C * idx; - normed += C * idx; - inp1 += C * idx; - inp2 += C * idx; - - const float eps = 1e-5f; - for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in1 = load128cs(inp1 + c); - const x128 in2 = load128cs(inp2 + c); - x128 out; - for(int k = 0; k < x128::size; ++k) { - out[k] = (float)in1[k] + (float)in2[k]; - } - store128cs(residual + c, out); - s_res[c / x128::size] = out; - } - - float v = 0.f; - - for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 res = s_res[c / x128::size]; - for(int k = 0; k < x128::size; ++k) { - v += (float)res[k] * (float)res[k]; - } - } - - v = warpReduceSum(v) / C; - float s = rsqrtf(v + eps); - - for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 res = s_res[c / x128::size]; - const x128 w = s_weight[c / x128::size]; - x128 out; - for(int k = 0; k < x128::size; ++k) { - float n = s * (float)res[k]; // normalized output - float o = n * (float)w[k]; // scale - out[k] = o; - } - - store128cs(normed + c, out); - } - // cache the rrms for the backward pass later - if(threadIdx.x == 0) { - rrms[idx] = s; - } -} - __global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; @@ -620,30 +489,6 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa cudaCheck(cudaGetLastError()); } -void fused_residual_rmsnorm_forward5(floatX* residual, floatX* normed, float* rrms, - const floatX* inp1, const floatX* inp2, - const floatX* weight, - int N, int C, cudaStream_t stream) { - const int block_size = 256; - int block_y = block_size / WARP_SIZE; - const int grid_size = CEIL_DIV(N, block_y); - size_t smem = (1 + block_y) * C * sizeof(floatX); - - // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute - // this may fail, in which case we fall back to the smem free implementation. - cudaCheck(cudaGetLastError()); - auto status = cudaFuncSetAttribute(fused_residual_rmsnorm_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cudaCheck(cudaGetLastError()); - if(status == cudaSuccess) { - fused_residual_rmsnorm_forward_kernel5<<>>(residual, normed, - rrms, inp1, inp2, - weight, N, C); - } else { - assert(false); - } - cudaCheck(cudaGetLastError()); -} - void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, int B, int T, int C, cudaStream_t stream) { @@ -658,27 +503,3 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr layernorm_backward_kernel10<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); cudaCheck(cudaGetLastError()); } - -void rmsnorm_forward(floatX* out, float* rms, - floatX* inp, const floatX* weight, - int B, int T, int C, cudaStream_t stream) { - NVTX_RANGE_FN(); - const int block_size = 256; - int block_y = block_size / WARP_SIZE; - const int N = B * T; - const int grid_size = CEIL_DIV(N, block_y); - size_t smem = (1 + block_y) * C * sizeof(floatX); - - // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute - // this may fail, in which case we fall back to the smem free implementation. - cudaCheck(cudaGetLastError()); - auto status = cudaFuncSetAttribute(rmsnorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cudaCheck(cudaGetLastError()); - if (status == cudaSuccess) { - rmsnorm_forward_kernel6<<>>(out, rms, inp, weight, N, C); - } else { - // We should not allow for these perf regressions for now - just throw an error - assert(false); - } - cudaCheck(cudaGetLastError()); -} diff --git a/llmc/rmsnorm.cuh b/llmc/rmsnorm.cuh new file mode 100644 index 000000000..8f20e9864 --- /dev/null +++ b/llmc/rmsnorm.cuh @@ -0,0 +1,380 @@ +/* +RMSNorm backward CUDA kernel. +*/ + +#include +#include "cuda_common.h" +#include "cuda_utils.cuh" + +// ---------------------------------------------------------------------------- +// CUDA kernels + +__global__ void rmsnorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ rms, + const floatX* __restrict__ inp, const floatX* __restrict__ weight, int N, int C) { + // this kernel is a simplified version of layernorm_forward_kernel6 + assert(blockDim.x == WARP_SIZE); + + // load weights into shared memory + // do this before we allow any threads to exit! + extern __shared__ char* params[]; + // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so + // let's keep everything as x128 + x128* s_weight = reinterpret_cast(params); + x128* s_in = reinterpret_cast(params) + ((1 + threadIdx.y) * C / x128::size); + + int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; + for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { + s_weight[i/x128::size] = load128(weight + i); + } + __syncthreads(); + + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if(idx >= N) { return; } // guard + + // adjust pointers to current token + inp += idx * C; + out += idx * C; + + const float eps = 1e-5f; + float acc = 0.f; + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = load128cs(inp + c); + s_in[c / x128::size] = in_data; + for(int k = 0; k < x128::size; ++k) { + float data_k = (float)in_data[k]; + acc += data_k * data_k; + } + } + + acc = warpReduceSum(acc) / C; + float s = rsqrtf(acc + eps); + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in_data = s_in[c / x128::size]; + const x128 w = s_weight[c / x128::size]; + x128 out_data; + for(int k = 0; k < x128::size; ++k) { + float n = s * (float)in_data[k]; // normalized output + float o = n * (float)w[k]; // scale + out_data[k] = (floatX)o; + } + + store128cs(out + c, out_data); + } + + // store the rms, no need to cache it + if(threadIdx.x == 0 && rms != nullptr) { + __stcs(rms + idx, s); + } +} + +__global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* normed, float* rrms, + const floatX* inp1, const floatX* inp2, + const floatX* weight, + int N, int C) { + assert(blockDim.x == WARP_SIZE); + + // load weights and biases into shared memory + // do this before we allow any threads to exit! + extern __shared__ char* params[]; + // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so + // let's keep everything as x128 + x128* s_weight = reinterpret_cast(params); + x128* s_res = reinterpret_cast(params) + ((1 + threadIdx.y) * C / x128::size); + + int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; + for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { + s_weight[i/x128::size] = load128(weight + i); + } + __syncthreads(); + + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if(idx > N) return; + + // adjust pointers to current token + residual += C * idx; + normed += C * idx; + inp1 += C * idx; + inp2 += C * idx; + + const float eps = 1e-5f; + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 in1 = load128cs(inp1 + c); + const x128 in2 = load128cs(inp2 + c); + x128 out; + for(int k = 0; k < x128::size; ++k) { + out[k] = (float)in1[k] + (float)in2[k]; + } + store128cs(residual + c, out); + s_res[c / x128::size] = out; + } + + float v = 0.f; + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 res = s_res[c / x128::size]; + for(int k = 0; k < x128::size; ++k) { + v += (float)res[k] * (float)res[k]; + } + } + + v = warpReduceSum(v) / C; + float s = rsqrtf(v + eps); + + for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { + const x128 res = s_res[c / x128::size]; + const x128 w = s_weight[c / x128::size]; + x128 out; + for(int k = 0; k < x128::size; ++k) { + float n = s * (float)res[k]; // normalized output + float o = n * (float)w[k]; // scale + out[k] = o; + } + + store128cs(normed + c, out); + } + // cache the rrms for the backward pass later + if(threadIdx.x == 0) { + rrms[idx] = s; + } +} + + +__global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? + rmsnorm_backward_kernel10(floatX* dinp, floatX* dweight, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, + const float* rstd, int B, int T, int C) { + // TODO: this kernel uses too much shared memory due to historical reasons of it coming from layernorm_backward.cu + // this memory use can be reduced by half later + int BLOCK_SIZE = blockDim.x; + int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block + extern __shared__ float shared[]; + + int warpId = threadIdx.x / WARP_SIZE; // warp index within a block + int baseIdx = blockIdx.x * warpsInBlock + warpId; + int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp + int warpsInGrid = gridDim.x * warpsInBlock; + int C_per_iteration = WARP_SIZE * x128::size; + int iterations_C = CEIL_DIV(C, C_per_iteration); // + 2; + + // the first half of shared memory is bias, second is weight + size_t rounded_C = CEIL_DIV(C, (32 * x128::size)) * (32 * x128::size); + float* dweight_shared = shared + rounded_C; + // warp zero doesn't actually write to the _tmp_shared memory locations, so we don't need to reserve memory + // the obvious solution is to change the addressing below to use (threadId.x-32) as offset, but that causes + // register spills, so instead we mess with the base pointer here, which doesn't increase register usage. + float* dweight_tmp_shared = shared + 2 * rounded_C + f128::size * BLOCK_SIZE - 2 * WARP_SIZE * f128::size; + + // init shared memory to zero + for(int i = threadIdx.x * f128::size; i < rounded_C; i += BLOCK_SIZE * f128::size) { + store128(dweight_shared + i, f128::zeros()); + } + __syncthreads(); + + for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) { + const floatX* dout_bt = dout + bt * C; + const floatX* inp_bt = inp +bt * C; + floatX* dinp_bt = dinp + bt * C; + + // first: two reduce operations + float dnorm_mean = 0.0f; + float dnorm_norm_mean = 0.0f; + for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { + x128 dout128_i = load128(dout_bt + i); + x128 inp128_i = load128(inp_bt + i); + x128 weight128_i = load128(weight + i); + for (int k = 0; k < x128::size; k++) { + float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; + dnorm_mean += dnorm_i; + dnorm_norm_mean += dnorm_i * (float)inp128_i[k]; + } + } + + const float rstd_bt = rstd[bt]; + dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C * rstd_bt; + + for (int c = 0; c < iterations_C; c++) { + int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); + + x128 dout128 = x128::zeros(); + x128 inp128 = x128::zeros(); + x128 dinp128 = x128::zeros(); + x128 weight128 = x128::zeros(); + + if(global_index < C) { + dout128 = load128cs(dout_bt + global_index); + inp128 = load128cs(inp_bt + global_index); + dinp128 = load128(dinp_bt + global_index); + weight128 = load128(weight + global_index); + } + + for(int o = 0; o < x128::size / f128::size; ++o) { + f128 dweight_f; + for(int i = 0; i < f128::size; ++i) { + int x = o * f128::size + i; + float dout_i = (float)dout128[x]; + float norm_bti = ((float)inp128[x]) * rstd_bt; + dweight_f[i] = norm_bti * dout_i; + + float dval = 0.0f; + dval += (float) weight128[x] * (float)dout128[x]; // term 1 + dval -= norm_bti * dnorm_norm_mean; // term 2 + dval *= rstd_bt; // final scale + dinp128[x] = (floatX) ((float) dinp128[x] + dval); + } + + if (warpId != 0) { + // this seems to generate a 64-bit store, instead of 128-bit. + // however, forcing 128-bit (e.g., using inline ptx), results in register + // spilling and much worse performance, so we'll keep it like this for now + // but ideally, we could reduce the register pressure a little. + store128(dweight_tmp_shared + threadIdx.x * f128::size, dweight_f); + } + __syncthreads(); + if (warpId == 0) { + for (int j = 1; j < warpsInBlock; j++) { + f128 dweight_tmp = load128(dweight_tmp_shared + f128::size * (threadIdx.x + j * WARP_SIZE)); + for(int i = 0; i < f128::size; ++i) { + dweight_f[i] += dweight_tmp[i]; + } + } + } + __syncthreads(); + if (warpId == 0) { + f128 dw_old = load128(dweight_shared + global_index + f128::size * o); + for(int i = 0; i < f128::size; ++i) { + dweight_f[i] += dw_old[i]; + } + store128(dweight_shared + global_index + f128::size * o, dweight_f); + } + } + if(global_index < C) { + // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing + store128cg(dinp_bt + global_index, dinp128); + } + } + } + __syncthreads(); + // Each block writes its partial sum to global memory + // The last block to finish becomes responsible for summing up all the partial sums + // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + unsigned int* scratchFlag = (unsigned int*)(scratch); + // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned + scratch += 32; + float* scratch_dweight = scratch + C; + for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) { + // Write to global memory in the same "shared memory banking friendly" order + store128(scratch_dweight + i + 2*C*blockIdx.x, load128(dweight_shared + i)); + } + __syncthreads(); + // that portion of shared memory is no longer used, so we can repurpose it for the scratch flag. + unsigned int *tmp_flag = (unsigned int*)(shared + 2*rounded_C); + if (threadIdx.x == 0) { + *tmp_flag = atomicInc(scratchFlag, gridDim.x); + } + __syncthreads(); + if (*tmp_flag == gridDim.x-1) { + // Reduction of the partial sums by the final block + // todo - there isn't enough parallelism even inside that single SM... + // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! + for(int i = threadIdx.x * f128::size; i < C; i += BLOCK_SIZE * f128::size) { + f128 dweight_accum = f128::zeros(); + + for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { + int offset = i + 2*C*read_block_idx; + f128 dweight128 = load128(scratch_dweight + offset); + for(int k = 0; k < f128::size; k++) { + dweight_accum[k] += dweight128[k]; + } + } + store128(dweight_shared + i, dweight_accum); + } + __syncthreads(); + + // convert from float/FP32 to floatX/BF16 for the final write + // this is separate because it cannot use as many warps as the above (f128 vs x128) + // todo - if we split this code into another kernel, we could maybe do it at the same time? + for (int c = warpId; c < iterations_C; c += warpsInBlock) { + int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); + if (global_index >= C) { + break; + } + x128 dweight128 = load128(dweight + global_index); + for(int o = 0; o < x128::size / f128::size; ++o) { + f128 s_dw = load128(dweight_shared + global_index + o * f128::size); + for(int i = 0; i < f128::size; ++i) { + int x = o * f128::size + i; + dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]); + } + } + store128(dweight + global_index, dweight128); + } + } +} + +// ---------------------------------------------------------------------------- +// Kernel launchers + +void rmsnorm_forward(floatX* out, float* rms, + floatX* inp, const floatX* weight, + int B, int T, int C, cudaStream_t stream) { + NVTX_RANGE_FN(); + const int block_size = 256; + int block_y = block_size / WARP_SIZE; + const int N = B * T; + const int grid_size = CEIL_DIV(N, block_y); + size_t smem = (1 + block_y) * C * sizeof(floatX); + + // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute + // this may fail, in which case we fall back to the smem free implementation. + cudaCheck(cudaGetLastError()); + auto status = cudaFuncSetAttribute(rmsnorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cudaCheck(cudaGetLastError()); + if (status == cudaSuccess) { + rmsnorm_forward_kernel6<<>>(out, rms, inp, weight, N, C); + } else { + // We should not allow for these perf regressions for now - just throw an error + assert(false); + } + cudaCheck(cudaGetLastError()); +} + +void fused_residual_rmsnorm_forward5(floatX* residual, floatX* normed, float* rrms, + const floatX* inp1, const floatX* inp2, + const floatX* weight, + int N, int C, cudaStream_t stream) { + // same as forward kernel but has a fused residual connection + const int block_size = 256; + int block_y = block_size / WARP_SIZE; + const int grid_size = CEIL_DIV(N, block_y); + size_t smem = (1 + block_y) * C * sizeof(floatX); + + // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute + // this may fail, in which case we fall back to the smem free implementation. + cudaCheck(cudaGetLastError()); + auto status = cudaFuncSetAttribute(fused_residual_rmsnorm_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + cudaCheck(cudaGetLastError()); + if(status == cudaSuccess) { + fused_residual_rmsnorm_forward_kernel5<<>>(residual, normed, + rrms, inp1, inp2, + weight, N, C); + } else { + assert(false); + } + cudaCheck(cudaGetLastError()); +} + +void rmsnorm_backward(floatX* dinp, floatX* dweight, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, const float* rstd, + int B, int T, int C, cudaStream_t stream) { + const int block_size = 512; + const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3 + const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount; + size_t rounded_C = CEIL_DIV(C, (32 * x128::size)) * (32 * x128::size); + size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); + + cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 + rmsnorm_backward_kernel10<<>>(dinp, dweight, scratch, dout, inp, weight, rstd, B, T, C); + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index e515bb185..6dbae50e7 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -46,8 +46,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. // defines: encoder_forward, encoder_backward #include "llmc/encoder.cuh" // defines: layernorm_forward, residual_forward, fused_residual_forward5, layernorm_backward -// defines: rmsnorm_forward, fused_residual_rmsnorm_forward5 #include "llmc/layernorm.cuh" +// defines: rmsnorm_forward, fused_residual_rmsnorm_forward5, rmsnorm_backward +#include "llmc/rmsnorm.cuh" // defines: matmul_cublaslt, matmul_forward, matmul_backward, gelu_forward, gelu_backward_inplace #include "llmc/matmul.cuh" #ifdef ENABLE_CUDNN @@ -796,26 +797,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream); - - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - float* output = acts.losses; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("q[%d] = %.8f\n", i, (float) cpu[i]); - } - // write to .bin file - // move output to cpu - floatX* cpu_output = (floatX*)mallocCheck(B*T * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, B*T * sizeof(floatX), cudaMemcpyDeviceToHost)); - FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), B*T, f); - fclose(f); - exit(0); - // ------------------------------------------------------------------------ - // ------------------------------------------------------------------------ // backward pass: go in the reverse order of the forward pass, and call backward() functions @@ -832,10 +813,29 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + matmul_backward(model->acts.scratch_bt4c, grads.wpe, NULL, acts.output, acts.lnf, params.wpe, NULL, B, T, C, Vp, main_stream); // backward the final layernorm floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 - layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream); + rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream); + + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + float* output = (float*)dresidual; + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("q[%d] = %.8f\n", i, (float) cpu[i]); + } + // write to .bin file + // move output to cpu + floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); + FILE* f = fopen("out.bin", "wb"); + fwrite(cpu_output, sizeof(floatX), B*T*C, f); + fclose(f); + exit(0); + // ------------------------------------------------------------------------ // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic // scratch for backward computations diff --git a/train_llama3.py b/train_llama3.py index d2b6c0c38..b87e6f414 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -301,7 +301,10 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): for i, block in enumerate(self.transformer.h): x = block(x, freqs_cis, start_pos, mask) - x = self.transformer.ln_f(x) + + self.DEBUG_INPUT = x.detach() + self.DEBUG_INPUT.requires_grad = True + x = self.transformer.ln_f(self.DEBUG_INPUT) if targets is not None: # if we are given some desired targets also calculate the loss @@ -312,17 +315,6 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): logits = self.lm_head(x[:, [-1], :]).float() # note: using list [-1] to preserve the time dim loss = None - # --------------------------------------------------------------------- - # DEBUGGING: print first 32 elements of x - x = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction='none') - for i in range(32): - print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) - # write to .bin file - with open("ref.bin", "wb") as f: - f.write(x.view(-1).cpu().detach().numpy().tobytes()) - breakpoint() - # --------------------------------------------------------------------- - # there are performance reasons why not returning logits is prudent, if not needed if not return_logits: logits = None @@ -1264,6 +1256,18 @@ def get_lr(it): # backward pass if not args.inference_only: loss.backward() + + # --------------------------------------------------------------------- + # DEBUGGING: print first 32 elements of x + x = model.DEBUG_INPUT.grad + for i in range(32): + print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) + # write to .bin file + with open("ref.bin", "wb") as f: + f.write(x.view(-1).cpu().detach().numpy().tobytes()) + breakpoint() + # --------------------------------------------------------------------- + if ddp: dist.all_reduce(lossf, op=dist.ReduceOp.AVG) lossf = lossf.item() From 01c2895e173926c3e9022a0ea6a65c3a4402980b Mon Sep 17 00:00:00 2001 From: Insop Song Date: Thu, 26 Sep 2024 16:34:50 -0700 Subject: [PATCH 30/63] Update RoPE naming --- train_llama3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_llama3.py b/train_llama3.py index b87e6f414..5f828ae67 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -3,7 +3,7 @@ Will save the model weights into files, to be read from C as initialization. This code differs from GPT-2 very slightly, there are three main differences: -1) RoPE: LLaMA uses a different positional encoding scheme called Relative Positional Encoding (RoPE). +1) RoPE: LLaMA uses a different positional encoding scheme called Rotary Position Embedding (RoPE). 2) GQA: Grouped Query Attention (GQA) is used to reduce the number of attention heads. 3) SwiGLU: Swish-Gated Linear Unit (SwiGLU) is used as the activation function in the MLP. From 1b54612c46e0970082781c22a07526564796f32d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 00:21:43 +0000 Subject: [PATCH 31/63] i can backward through MLP block. Attention block is next --- llmc/swiglu.cuh | 46 +++++++++++++++++++++++++ train_llama3.cu | 92 ++++++++++++++++++++++++++++--------------------- train_llama3.py | 13 +++---- 3 files changed, 106 insertions(+), 45 deletions(-) diff --git a/llmc/swiglu.cuh b/llmc/swiglu.cuh index a574233d3..4eda4c3f0 100644 --- a/llmc/swiglu.cuh +++ b/llmc/swiglu.cuh @@ -63,6 +63,41 @@ __global__ void swiglu_forward_kernel2(floatX* out, const floatX* inp, int B, in out[idx] = (floatX)((x1 * x2) / (1.0f + expf(-x2))); } +__global__ void swiglu_backward_kernel1(floatX* dinp, const floatX* dout, const floatX* inp, int B, int T, int C) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + const floatX* dout_ptr = dout + idx; + // b,t,c in the output + int b = idx / (T * C); + int t = (idx / C) % T; + int c = idx % C; + // coords in input + int C2 = C * 2; + const floatX* inp1_ptr = inp + (b * T * C2 + t * C2 + c); + const floatX* inp2_ptr = inp1_ptr + C; + floatX* dinp1_ptr = dinp + (b * T * C2 + t * C2 + c); + floatX* dinp2_ptr = dinp1_ptr + C; + // backward + x128 dinp1; + x128 dinp2; + x128 packed_dout = load128cs(dout_ptr); + x128 packed_inp1 = load128cs(inp1_ptr); // fc1 + x128 packed_inp2 = load128cs(inp2_ptr); // fc2 + for(int k = 0; k < packed_inp1.size; ++k) { + float x1 = (float)packed_inp1[k]; + float x2 = (float)packed_inp2[k]; + float dout = (float)packed_dout[k]; + + float sx2 = 1.0f / (1.0f + expf(-x2)); // sigmoid of x2 + float dx1 = dout * x2 * sx2; + float dx2 = dout * x1 * sx2 * (1.0f + x2 * (1.0f - sx2)); + + dinp1[k] = (floatX)dx1; + dinp2[k] = (floatX)dx2; + } + store128(dinp1_ptr, dinp1); + store128(dinp2_ptr, dinp2); +} + // ---------------------------------------------------------------------------- // kernel launchers @@ -84,3 +119,14 @@ void swiglu_forward_naive(floatX* out, const floatX* inp, int B, int T, int C, c swiglu_forward_kernel2<<>>(out, inp, B, T, C); cudaCheck(cudaGetLastError()); } + +void swiglu_backward(floatX* dinp, const floatX* dout, const floatX* inp, int B, int T, int C, cudaStream_t stream) { + // input is (B, T, 2C), output is (B, T, C) + // we have that inp[b, t, :] = [fc1, fc2] (i.e. they are concatenated in each C-fiber) + NVTX_RANGE_FN(); + const int block_size = 128; + assert((B*T*C) % (block_size * x128::size) == 0); + const int grid_size = CEIL_DIV(B*T*C, block_size * x128::size); + swiglu_backward_kernel1<<>>(dinp, dout, inp, B, T, C); + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index 6dbae50e7..630342ab7 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -68,7 +68,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/repkv.cuh" // defines: precompute_freqs_cis, rope_forward #include "llmc/rope.cuh" -// defines: swiglu_forward +// defines: swiglu_forward, swiglu_backward #include "llmc/swiglu.cuh" // ----------- Multi-GPU support ----------- // defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo @@ -289,8 +289,8 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T); tensors[16] = TENSOR_SPEC(data->losses, B * T); tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * qkv_channels); - tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(NH*T, Vp))); - tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); + tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp)))); + tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels); tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); } @@ -786,6 +786,16 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; + const size_t n_head = model->config.num_heads; + const size_t n_kv_head = model->config.num_kv_heads; + const size_t hd = C / n_head; // head dimension + const size_t qkv_channels = (n_head + 2*n_kv_head) * hd; // Q, K, V channels + size_t hidden_dim = 4 * C; + hidden_dim = (2 * hidden_dim) / 3; + hidden_dim = model->config.ffn_dim_multiplier * hidden_dim; + hidden_dim = model->config.multiple_of * ((hidden_dim + model->config.multiple_of - 1) / model->config.multiple_of); + size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated + size_t ffn_channels_post_gelu = hidden_dim; // swiglu halves the channels ParameterTensors params = model->params; // for brevity ParameterTensors grads = model->grads; @@ -817,26 +827,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // backward the final layernorm floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream); - - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - float* output = (float*)dresidual; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("q[%d] = %.8f\n", i, (float) cpu[i]); - } - // write to .bin file - // move output to cpu - floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); - FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), B*T*C, f); - fclose(f); - exit(0); - // ------------------------------------------------------------------------ - // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic // scratch for backward computations floatX* dl_btc = residual; @@ -850,37 +840,36 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // get the pointers of the weights for this layer floatX* l_ln1w = params.ln1w + l * C; floatX* l_ln1b = params.ln1b + l * C; - floatX* l_qkvw = params.qkvw + l * 3*C * C; + floatX* l_qkvw = params.qkvw + l * qkv_channels * C; floatX* l_attprojw = params.attprojw + l * C * C; floatX* l_ln2w = params.ln2w + l * C; - floatX* l_ln2b = params.ln2b + l * C; - floatX* l_fcw = params.fcw + l * 4*C * C; - floatX* l_fcprojw = params.fcprojw + l * C * 4*C; + floatX* l_fcw = params.fcw + l * ffn_channels * C; + floatX* l_fcprojw = params.fcprojw + l * C * ffn_channels_post_gelu; // get the pointers of the gradients of the weights for this layer floatX* dl_ln1w = grads.ln1w + l * C; floatX* dl_ln1b = grads.ln1b + l * C; - floatX* dl_qkvw = grads.qkvw + l * 3*C * C; - floatX* dl_qkvb = grads.qkvb + l * 3*C; + floatX* dl_qkvw = grads.qkvw + l * qkv_channels * C; + floatX* dl_qkvb = grads.qkvb + l * qkv_channels; floatX* dl_attprojw = grads.attprojw + l * C * C; floatX* dl_attprojb = grads.attprojb + l * C; floatX* dl_ln2w = grads.ln2w + l * C; floatX* dl_ln2b = grads.ln2b + l * C; - floatX* dl_fcw = grads.fcw + l * 4*C * C; - floatX* dl_fcb = grads.fcb + l * 4*C; - floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C; + floatX* dl_fcw = grads.fcw + l * ffn_channels * C; + floatX* dl_fcb = grads.fcb + l * ffn_channels; + floatX* dl_fcprojw = grads.fcprojw + l * C * ffn_channels_post_gelu; floatX* dl_fcprojb = grads.fcprojb + l * C; // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; float* l_ln1_mean = acts.ln1_mean + l * B * T; float* l_ln1_rstd = acts.ln1_rstd + l * B * T; - floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; + floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; float* l_ln2_mean = acts.ln2_mean + l * B * T; float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C; - floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; + floatX* l_fch_pre_gelu = acts.fch + l * B * T * ffn_channels; + floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; // get the pointers of the gradients of the activations for this layer // notice that there is no l *, because we just have a single copy, and keep // re-using this memory in every Transformer block as we calculate backward pass @@ -891,14 +880,39 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu. in this case, // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here - gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream); + // gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream); + swiglu_forward(l_fch_gelu, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); } - matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion); + // backward the 2nd matmul of MLP + matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream); + // backward the swiglu here, use scratchX to hold the grad because SwiGLU can't be inplace + swiglu_backward(scratchX, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); + // backward the 1st matmul of MLP if(model->recompute >= 2) { // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand - layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream); + rmsnorm_forward(l_ln2, l_ln2_rstd, l_residual2, l_ln2w, B, T, C, main_stream); } - matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream); + matmul_backward(dl_btc, dl_fcw, dl_fcb, scratchX, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream); + + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + float* output = (float*)dl_btc; + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("q[%d] = %.8f\n", i, (float) cpu[i]); + } + // write to .bin file + // move output to cpu + floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); + FILE* f = fopen("out.bin", "wb"); + fwrite(cpu_output, sizeof(floatX), B*T*C, f); + fclose(f); + exit(0); + // ------------------------------------------------------------------------ + // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream); matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); diff --git a/train_llama3.py b/train_llama3.py index b87e6f414..0e1a35ac3 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -234,7 +234,11 @@ def __init__(self, config): def forward(self, x, freqs_cis=None, start_pos=None, mask=None): x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) - x = x + self.mlp(self.ln_2(x)) + MLP_INPUT = self.ln_2(x) + MLP_INPUT = MLP_INPUT.detach() + MLP_INPUT.requires_grad = True + self.MLP_INPUT = MLP_INPUT + x = x + self.mlp(MLP_INPUT) return x # ----------------------------------------------------------------------------- @@ -301,10 +305,7 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): for i, block in enumerate(self.transformer.h): x = block(x, freqs_cis, start_pos, mask) - - self.DEBUG_INPUT = x.detach() - self.DEBUG_INPUT.requires_grad = True - x = self.transformer.ln_f(self.DEBUG_INPUT) + x = self.transformer.ln_f(x) if targets is not None: # if we are given some desired targets also calculate the loss @@ -1259,7 +1260,7 @@ def get_lr(it): # --------------------------------------------------------------------- # DEBUGGING: print first 32 elements of x - x = model.DEBUG_INPUT.grad + x = model.transformer.h[-1].MLP_INPUT.grad for i in range(32): print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) # write to .bin file From 28e4a7f83e9d948df289e460cdf224d3dea006bf Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 01:13:32 +0000 Subject: [PATCH 32/63] small fixes, but still not too happy with this kernel, it wastes thread and more efficient implementation kernel2 is desireable and desired --- dev/cuda/repkv_backward.cu | 131 ++++++++----------------------------- 1 file changed, 27 insertions(+), 104 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index 93f2720fa..9a00205d9 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -1,20 +1,6 @@ /* -Layer that takes a QKV tensor of shape (B, T, C) and replicates the K,V -some number of times. For example, if B=4, T=64, C=6144, and we have that: -- head dimension (hd) is 128 channels -- query heads: 32 -- key heads: 8 -- value heads: 8 -- so number of heads = 32 + 8 + 8 = 48, each of 128 channels, total of 6144 channels -We want to replicate the key/value vectors 4X, so that we get: -32 + 32 + 32 = 96 query, key, value heads, each of 128 channels, total of 12288 channels -Each of these vectors should be replicated by simple copying/concat 4X times. - -Compile and run as: -make repkv_backward -./repkv_backward 1 - -block_size 128 seems fastest on H100 +See repkv.cu for details. This is the backward pass of repkv forward. +Block size 128 seems fastest on H100 */ #include @@ -24,21 +10,19 @@ block_size 128 seems fastest on H100 #include "common.h" // cpu reference code -void repkv_backward_cpu(float* dinp, const float* inp, const float* doutp, +void repkv_backward_cpu(float* dinp, const float* dout, const int B, const int T, const int Cout, const int hd, const int qh, const int kh, const int vh) { assert(Cout == (hd * (3 * qh))); assert(kh == vh); - int nrep = qh / kh; // number of times to replicate key/value vectors - int Cin = hd * (qh + kh + vh); // output channels for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - // seek to the input position doutp[b,t,:] - const float* x = doutp + b * T * Cout + t * Cout; + // seek to the input position dout[b,t,:] + const float* x = dout + b * T * Cout + t * Cout; // seek to the output position out[b,t,:] float* y = dinp + b * T * Cin + t * Cin; // copy all the query vectors, no changes @@ -66,17 +50,15 @@ void repkv_backward_cpu(float* dinp, const float* inp, const float* doutp, } // kernels -__global__ void repkv_backward_kernel1(floatX* dinp, - const floatX* inp, const floatX* doutp, +__global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, int B, int N, int NH, int replicate_factor, int HD) { - // we have a single tensor doutp of shapae of (B, N 3 * NH * HD) + // we have a single tensor dout of shapae of (B, N 3 * NH * HD) // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD) int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= B * N * 3 * NH * HD) { return;} - int doutp_idx = idx; // keep backup + int dout_idx = idx; // keep backup - // decode the doutp index + // decode the dout index int d = idx % HD; idx /= HD; int nh = idx % NH; @@ -91,12 +73,12 @@ __global__ void repkv_backward_kernel1(floatX* dinp, if (c == 0) { dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; - dinp[dinp_idx] = __ldcs(&doutp[doutp_idx]); + dinp[dinp_idx] = __ldcs(&dout[dout_idx]); } else if (c == 1) { if (nh % replicate_factor == 0) { float reduced_sum = 0; for (int i = 0; i < replicate_factor; i++) { - reduced_sum += __ldcs(&doutp[doutp_idx+HD*i]); + reduced_sum += __ldcs(&dout[dout_idx+HD*i]); } dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; @@ -107,7 +89,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, if (nh % replicate_factor == 0) { float reduced_sum = 0; for (int i = 0; i < replicate_factor; i++) { - reduced_sum += __ldcs(&doutp[doutp_idx+HD*i]); + reduced_sum += __ldcs(&dout[dout_idx+HD*i]); } dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; dinp[dinp_idx] = reduced_sum; @@ -116,23 +98,23 @@ __global__ void repkv_backward_kernel1(floatX* dinp, } // kernel launchers -void repkv_backward1(floatX* dinp, const floatX* inp, const floatX* doutp, +void repkv_backward1(floatX* dinp, const floatX* dout, const int B, const int T, const int NH, const int NH_KV, const int d, int block_size) { int total_threads = B * T * (3 * NH) * d; int num_blocks = ceil_div(total_threads, block_size); int replicate_factor = NH / NH_KV; - repkv_backward_kernel1<<>>(dinp, inp, doutp, B, T, NH, replicate_factor, d); + repkv_backward_kernel1<<>>(dinp, dout, B, T, NH, replicate_factor, d); cudaCheck(cudaGetLastError()); } // kernel dispatcher void repkv_backward(int kernel_num, - floatX* dinp, const floatX* inp, const floatX* doutp, + floatX* dinp, const floatX* dout, int B, int T, int NH, int NH_KV, int d, int block_size) { switch (kernel_num) { case 1: - repkv_backward1(dinp, inp, doutp, B, T, NH, NH_KV, d, block_size); + repkv_backward1(dinp, dout, B, T, NH, NH_KV, d, block_size); break; default: printf("Invalid kernel number\n"); @@ -140,64 +122,16 @@ void repkv_backward(int kernel_num, } } -#ifdef DEBUG -static void log_mat(float *inp, int B, int T, int C, int hd, int qh, int kh, int vh, char *title) -{ - printf("%s -----\n", title); - for (int b = 0; b < B; b++) { - printf("batch : %d ", b); - for (int t = 0; t < T; t++) { - printf("t = %d\n", t); - const float* x = inp + b * T * C + t * C; - printf("Query\n"); - for (int h=0; h < qh; h++) { - for (int i = 0; i < hd; i++) { - printf("%f ", x[i]); - } - x += hd; // advance input pointer - printf("\n"); - } - printf("Key\n"); - for (int h=0; h < kh; h++) { - for (int i = 0; i < hd; i++) { - printf("%f ", x[i]); - } - x += hd; // advance input pointer - printf("\n"); - } - printf("Value\n"); - for (int h=0; h < vh; h++) { - for (int i = 0; i < hd; i++) { - printf("%f ", x[i]); - } - x += hd; // advance input pointer - printf("\n"); - } - } - } - printf("\n"); -} -#endif // DEBUG - // tester int main(int argc, char **argv) { srand(0); -#ifdef DEBUG - int B = 1; - int T = 2; - int hd = 3; // head dim - int qh = 4; // num query heads - int kh = 2; // num key heads - int vh = 2; // num value heads -#else int B = 8; int T = 1024; int hd = 128; // head dim int qh = 32; // num query heads int kh = 8; // num key heads int vh = 8; // num value heads -#endif int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); @@ -208,16 +142,15 @@ int main(int argc, char **argv) { // allocate (and fill) CPU memory float* dinp = (float*)malloc(B * T * Cin * sizeof(float)); memset(dinp, 0, B * T * Cin * sizeof(float)); - float* inp = make_random_float(B * T * Cin); - float* doutp = make_random_float(B * T * Cout * sizeof(float)); + float* dout = make_random_float(B * T * Cout * sizeof(float)); // allocate GPU memory float* d_dinp; float* d_inp; - float* d_doutp; + float* d_dout; cudaCheck(cudaMalloc(&d_dinp, B * T * Cin * sizeof(float))); cudaCheck(cudaMalloc(&d_inp, B * T * Cin * sizeof(float))); - cudaCheck(cudaMalloc(&d_doutp, B * T * Cout * sizeof(float))); + cudaCheck(cudaMalloc(&d_dout, B * T * Cout * sizeof(float))); // read kernel_num from command line int kernel_num = 1; @@ -226,26 +159,16 @@ int main(int argc, char **argv) { } printf("Using kernel %d\n", kernel_num); -#ifdef DEBUG - int nrep = qh/kh; - log_mat(doutp, B, T, Cout, hd, qh, nrep*kh, nrep*vh, "doutp"); -#endif // DEBUG - // CPU reference calculate - repkv_backward_cpu(dinp, inp, doutp, B, T, Cout, hd, qh, kh, vh); - -#ifdef DEBUG - log_mat(dinp, B, T, Cin, hd, qh, kh, vh, "dinp"); -#endif // DEBUG + repkv_backward_cpu(dinp, dout, B, T, Cout, hd, qh, kh, vh); // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; - cudaCheck(cudaMemcpy(d_dinp, inp, B * T * Cin * sizeof(float), cudaMemcpyHostToDevice)); - cudaCheck(cudaMemcpy(d_doutp, doutp, B * T * Cout * sizeof(float), cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(d_dout, dout, B * T * Cout * sizeof(float), cudaMemcpyHostToDevice)); for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; printf("Checking block size %d.\n", block_size); - repkv_backward(kernel_num, d_dinp, d_inp, d_doutp, B, T, qh, kh, hd, block_size); + repkv_backward(kernel_num, d_dinp, d_dout, B, T, qh, kh, hd, block_size); validate_result(d_dinp, dinp, "out", B * T * Cin, 1e-5f); } printf("All results match. Starting benchmarks.\n\n"); @@ -255,17 +178,17 @@ int main(int argc, char **argv) { int block_size = block_sizes[j]; int repeat_times = 1000; float elapsed_time = benchmark_kernel(repeat_times, repkv_backward, kernel_num, - d_dinp, d_inp, d_doutp, B, T, qh, kh, hd, block_size); + d_dinp, d_dout, B, T, qh, kh, hd, block_size); printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } // free memory - free(inp); free(dinp); - free(doutp); - + free(dout); cudaCheck(cudaFree(d_dinp)); cudaCheck(cudaFree(d_inp)); - cudaCheck(cudaFree(d_doutp)); + cudaCheck(cudaFree(d_dout)); + + return 0; } From 075e430d23f8fa887bb290cb64b509a753c8c0b5 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 03:37:53 +0000 Subject: [PATCH 33/63] just pushing what i have. it's epsilon away from working sigh. basically at this point of where prints happen, gradients match. but once we backward attention, rope and repkv, gradients don't match. attention hasn't changed so that can't be wrong (?), so it's either repkv or rope. i have to go slower and double check the backward pass of both of these in detail. also had to introduce one more additional buffer for backward --- llmc/repkv.cuh | 58 +++++++++++++++++++++++++++++++++++++++++++++++++ llmc/rope.cuh | 41 +++++++++++++++++++++++++++++++--- train_llama3.cu | 48 +++++++++++++++++++++++----------------- train_llama3.py | 14 +++++++----- 4 files changed, 132 insertions(+), 29 deletions(-) diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh index 666ad8c44..f4c517eaa 100644 --- a/llmc/repkv.cuh +++ b/llmc/repkv.cuh @@ -48,6 +48,54 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv, replicated_qkv[idx_flat] = __ldcs(&gqa_qkv[inp_idx]); } +__global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, + int B, int N, int NH, int replicate_factor, int HD) { + // we have a single tensor dout of shapae of (B, N 3 * NH * HD) + // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= B * N * 3 * NH * HD) { return;} + int dout_idx = idx; // keep backup + + // decode the dout index + int d = idx % HD; + idx /= HD; + int nh = idx % NH; + idx /= NH; + int c = idx % 3; + idx /= 3; + int n = idx % N; + int b = idx / N; + + int dinp_idx; + int nh_total = NH + 2 * (NH / replicate_factor); + + if (c == 0) { + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 0 * NH * HD + nh * HD + d; + dinp[dinp_idx] = __ldcs(&dout[dout_idx]); + } else if (c == 1) { + if (nh % replicate_factor == 0) { + float reduced_sum = 0; + for (int i = 0; i < replicate_factor; i++) { + reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); + } + + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + 1 * NH * HD + (nh / replicate_factor) * HD + d; + dinp[dinp_idx] = reduced_sum; + } + + } else { + if (nh % replicate_factor == 0) { + float reduced_sum = 0; + for (int i = 0; i < replicate_factor; i++) { + reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); + } + dinp_idx = b * N * nh_total * HD + n * nh_total * HD + (NH * HD + (NH / replicate_factor) * HD) + (nh / replicate_factor) * HD + d; + dinp[dinp_idx] = reduced_sum; + } + } +} + +// kernel launchers void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_KV, int HD, cudaStream_t stream) { // NH = number of query heads, NH_KV = number of key and value heads, HD = head dimension const int block_size = 128; @@ -61,3 +109,13 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_ } cudaCheck(cudaGetLastError()); } + +void repkv_backward(floatX* dinp, const floatX* dout, + const int B, const int T, const int NH, const int NH_KV, const int d) { + const int block_size = 128; + int total_threads = B * T * (3 * NH) * d; + int num_blocks = CEIL_DIV(total_threads, block_size); + int replicate_factor = NH / NH_KV; + repkv_backward_kernel1<<>>(dinp, dout, B, T, NH, replicate_factor, d); + cudaCheck(cudaGetLastError()); +} diff --git a/llmc/rope.cuh b/llmc/rope.cuh index ca5fc56f9..50371c47b 100644 --- a/llmc/rope.cuh +++ b/llmc/rope.cuh @@ -58,18 +58,44 @@ __global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const float int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; int idxi = idx_bth + 2 * d; // index in the input - // fetch the input - float x_real = inp[idxi]; - float x_imag = inp[idxi + 1]; // fetch the freqs_cis int freqs_idx = t * head_dim + 2 * d; float freqs_cos = freqs_cis[freqs_idx]; float freqs_sin = freqs_cis[freqs_idx + 1]; + // fetch the input + float x_real = inp[idxi]; + float x_imag = inp[idxi + 1]; // apply the rotation out[idxi] = x_real * freqs_cos - x_imag * freqs_sin; out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos; } +__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_dim_half = head_dim / 2; + if (idx >= B * T * 3 * n_head * head_dim_half) return; + // decode the qkv index early so we can early exit if it's a value index + int qkv = (idx / (n_head * head_dim_half)) % 3; + if (qkv == 2) return; // no-op for v + // decode the individual indices and get the input index + int b = idx / (T * 3 * n_head * head_dim_half); + int t = (idx / (3 * n_head * head_dim_half)) % T; + int h = (idx / head_dim_half) % n_head; + int d = idx % head_dim_half; + int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); + int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; + int idxi = idx_bth + 2 * d; // index in the input + // fetch the freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // backward + float dout_real = (float)dout[idxi]; + float dout_imag = (float)dout[idxi + 1]; + dinp[idxi] = dout_real * freqs_cos + dout_imag * freqs_sin; + dinp[idxi + 1] = -dout_real * freqs_sin + dout_imag * freqs_cos; +} + void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { // the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v // we are going to launch exactly one thread per element of the output, @@ -81,3 +107,12 @@ void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); cudaCheck(cudaGetLastError()); } + +void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { + // backward pass of forward, mirrors the forward kernel in setup and indexing + const int block_size = 128; + int total_threads = B * T * 3 * n_head * head_dim / 2; + int num_blocks = CEIL_DIV(total_threads, block_size); + rope_backward_inplace_kernel1<<>>(dinp, dout, freqs_cis, B, T, n_head, head_dim); + cudaCheck(cudaGetLastError()); +} diff --git a/train_llama3.cu b/train_llama3.cu index 630342ab7..2cc554a0b 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -64,9 +64,9 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/adamw.cuh" // defines: global_norm_squared #include "llmc/global_norm.cuh" -// defines: repkv_forward +// defines: repkv_forward, repkv_backward #include "llmc/repkv.cuh" -// defines: precompute_freqs_cis, rope_forward +// defines: precompute_freqs_cis, rope_forward, rope_backward_inplace #include "llmc/rope.cuh" // defines: swiglu_forward, swiglu_backward #include "llmc/swiglu.cuh" @@ -197,7 +197,7 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen return params_memory; } -constexpr int NUM_ACTIVATION_TENSORS = 21; +constexpr int NUM_ACTIVATION_TENSORS = 22; typedef struct { floatX* encoded; // (B, T, C) floatX* ln1; // (L, B, T, C) @@ -234,6 +234,7 @@ typedef struct { // some additional scratch buffers floatX* scratch_bt4c; // (B, T, 4*C) floatX* scratch_btc; // (B, T, C) + floatX* scratch_bt4c2; // (B, T, 4*C), for simplicify use this one for backward pass too, probably not needed } ActivationTensors; @@ -292,6 +293,7 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp)))); tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels); tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); + tensors[21] = TENSOR_SPEC(data->scratch_bt4c2, B * T * ffn_channels); } void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) { @@ -839,7 +841,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // get the pointers of the weights for this layer floatX* l_ln1w = params.ln1w + l * C; - floatX* l_ln1b = params.ln1b + l * C; floatX* l_qkvw = params.qkvw + l * qkv_channels * C; floatX* l_attprojw = params.attprojw + l * C * C; floatX* l_ln2w = params.ln2w + l * C; @@ -860,13 +861,11 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* dl_fcprojb = grads.fcprojb + l * C; // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + l * B * T; float* l_ln1_rstd = acts.ln1_rstd + l * B * T; floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; float* l_ln2_rstd = acts.ln2_rstd + l * B * T; floatX* l_fch_pre_gelu = acts.fch + l * B * T * ffn_channels; floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; @@ -875,6 +874,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // re-using this memory in every Transformer block as we calculate backward pass floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c; + floatX* dl_bt4c2 = (floatX*)model->acts.scratch_bt4c2; // same size as dl_bt4c, just a second buffer // start the backward pass for this layer if(model->recompute >= 1) { @@ -886,13 +886,16 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // backward the 2nd matmul of MLP matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream); // backward the swiglu here, use scratchX to hold the grad because SwiGLU can't be inplace - swiglu_backward(scratchX, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); + swiglu_backward(dl_bt4c2, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); // backward the 1st matmul of MLP if(model->recompute >= 2) { // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand rmsnorm_forward(l_ln2, l_ln2_rstd, l_residual2, l_ln2w, B, T, C, main_stream); } - matmul_backward(dl_btc, dl_fcw, dl_fcb, scratchX, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream); + matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c2, l_ln2, l_fcw, scratchF, B, T, C, ffn_channels, main_stream); + // rmsnorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above + rmsnorm_backward(dresidual, dl_ln2w, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_rstd, B, T, C, main_stream); + matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here @@ -905,19 +908,18 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } // write to .bin file // move output to cpu - floatX* cpu_output = (floatX*)mallocCheck(B*T*C * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, B*T*C * sizeof(floatX), cudaMemcpyDeviceToHost)); + // int sz = B*T*qkv_channels; //B*T*C; + int sz = B*T*C; + floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), B*T*C, f); + fwrite(cpu_output, sizeof(floatX), sz, f); fclose(f); exit(0); // ------------------------------------------------------------------------ - // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above - layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream); - matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); - #ifdef ENABLE_CUDNN + printf("cuDNN path TODO\n"); exit(0); float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); #else @@ -927,13 +929,19 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); #endif + // backward rope (this can be done in-place) + rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); + // backward repkv (use scratchX as gradient buffer here) + repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); + + // <--- here the gradients don't match, so there is an issue in between + + // backward QKV projection if(model->recompute >= 2) { - layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream); + rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); } - // QKV parameter gradients - matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream); - // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above - layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream); + matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c2, l_ln1, l_qkvw, scratchF, B, T, C, qkv_channels, main_stream); + rmsnorm_backward(dresidual, dl_ln1w, scratchF, dl_btc, residual, l_ln1w, l_ln1_rstd, B, T, C, main_stream); // Accumulate gradients from this layer in a background stream. if(last_step) { diff --git a/train_llama3.py b/train_llama3.py index 455282e1d..cd1549b42 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -197,6 +197,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): att = F.softmax(scores.float(), dim=-1).type_as(q) y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) + + DEBUG_POINT = y.detach() + DEBUG_POINT = DEBUG_POINT.requires_grad_(True) + self.DEBUG_POINT = DEBUG_POINT + y = DEBUG_POINT + y = self.c_proj(y) return y @@ -234,11 +240,7 @@ def __init__(self, config): def forward(self, x, freqs_cis=None, start_pos=None, mask=None): x = x + self.attn(self.ln_1(x), freqs_cis, start_pos, mask) - MLP_INPUT = self.ln_2(x) - MLP_INPUT = MLP_INPUT.detach() - MLP_INPUT.requires_grad = True - self.MLP_INPUT = MLP_INPUT - x = x + self.mlp(MLP_INPUT) + x = x + self.mlp(self.ln_2(x)) return x # ----------------------------------------------------------------------------- @@ -1260,7 +1262,7 @@ def get_lr(it): # --------------------------------------------------------------------- # DEBUGGING: print first 32 elements of x - x = model.transformer.h[-1].MLP_INPUT.grad + x = model.transformer.h[-1].attn.DEBUG_POINT.grad for i in range(32): print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) # write to .bin file From 8d490622931b879c76a192bc75127b9b24b2197a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 18:07:10 +0000 Subject: [PATCH 34/63] add backward kernel to dev/cuda for rope, to ensure correctness. but i mean, it's trivial. this can't possibly be the issue. it must be the repkv --- dev/cuda/rope.cu | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/dev/cuda/rope.cu b/dev/cuda/rope.cu index 4e45a4711..d8a457c11 100644 --- a/dev/cuda/rope.cu +++ b/dev/cuda/rope.cu @@ -122,6 +122,66 @@ void rope_forward(int kernel_num, floatX *out, const floatX *inp, const floatX * } } +// ---------------------------------------------------------------------------- +// while we're at it, let's also briefly validate our backward kernel here + +void apply_rotary_emb_backward(float *dinp, const float *dout, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) { + // backward pass of the RoPE embedding + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + for (int h = 0; h < n_head; h++) { + int idx_bth = idx_bt + h * head_dim; + for (int d = 0; d < head_dim / 2; d++) { + // fetch the angle from freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // and the input index we'll be updating + int idx = idx_bth + 2 * d; + // backward pass is simple because freqs_cis is just scaling by a constant + dinp[idx] += dout[idx] * freqs_cos + dout[idx + 1] * freqs_sin; + dinp[idx + 1] += -dout[idx] * freqs_sin + dout[idx + 1] * freqs_cos; + } + } + } + } +} + +__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int head_dim_half = head_dim / 2; + if (idx >= B * T * n_head * head_dim_half) return; + // decode the individual indices + int b = idx / (T * n_head * head_dim_half); + int t = (idx / (n_head * head_dim_half)) % T; + int h = (idx / head_dim_half) % n_head; + int d = idx % head_dim_half; + // calculate the index in the input + int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim); + int idx_bth = idx_bt + h * head_dim; + int idxi = idx_bth + 2 * d; // index in the input + // fetch the freqs_cis + int freqs_idx = t * head_dim + 2 * d; + float freqs_cos = freqs_cis[freqs_idx]; + float freqs_sin = freqs_cis[freqs_idx + 1]; + // apply the rotation + float dout_real = (float)dout[idxi]; + float dout_imag = (float)dout[idxi + 1]; + dinp[idxi] = dout_real * freqs_cos + dout_imag * freqs_sin; + dinp[idxi + 1] = -dout_real * freqs_sin + dout_imag * freqs_cos; +} + +void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { + // backward pass of forward, mirrors the forward kernel in setup and indexing + const int block_size = 128; + int total_threads = B * T * 3 * n_head * head_dim / 2; + int num_blocks = ceil_div(total_threads, block_size); + rope_backward_inplace_kernel1<<>>(dinp, dout, freqs_cis, B, T, n_head, head_dim); + cudaCheck(cudaGetLastError()); +} + +// ---------------------------------------------------------------------------- // tester int main(int argc, char **argv) { srand(0); @@ -179,6 +239,16 @@ int main(int argc, char **argv) { printf("block_size %4d time %.4f ms\n", block_size, elapsed_time); } + // now also briefly validate the backward pass + // first, the reference CPU calculation + float *dinp = (float *)malloc(B * T * n_head * head_dim * sizeof(float)); + memset(dinp, 0, B * T * n_head * head_dim * sizeof(float)); // init at zero + apply_rotary_emb_backward(dinp, out, inp, freqs_cis, B, T, n_head, head_dim); + // now the GPU calculation (note it is done in-place, as we wish it to be to save space) + rope_backward_inplace(d_out, d_out, d_freqs_cis, B, T, n_head, head_dim, 0); + validate_result(d_out, dinp, "dinp", B * T * n_head * head_dim, 1e-5f); + printf("Backward pass result matches.\n"); + // free memory free(inp); free(freqs_cis); From 7d945e994cc105182a3c4d62f0cc8990a62cb5ec Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 19:25:09 +0000 Subject: [PATCH 35/63] reshuffle repkv a bit, i wrote it from scratch. the kernel is still correct. repkv backward looks correct. rope backward is trivial so i don't see how it's not correct, and i also checked it. basically i'm really confused right now --- dev/cuda/repkv_backward.cu | 52 +++++++++++++++++++++----------------- llmc/repkv.cuh | 4 +-- train_llama3.cu | 43 ++++++++++++++++--------------- train_llama3.py | 12 ++++----- 4 files changed, 60 insertions(+), 51 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index 9a00205d9..84064c530 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -11,39 +11,46 @@ Block size 128 seems fastest on H100 // cpu reference code void repkv_backward_cpu(float* dinp, const float* dout, - const int B, const int T, const int Cout, - const int hd, const int qh, const int kh, const int vh) { - - assert(Cout == (hd * (3 * qh))); + int B, int T, int C, + int hd, int qh, int kh, int vh) { + // inp is (B, T, C) + // out is (B, T, 3, NH, HD) + // hd = head dimension + // qh, kh, vh = number of query, key, value heads + assert(C == hd * (qh + kh + vh)); assert(kh == vh); int nrep = qh / kh; // number of times to replicate key/value vectors - int Cin = hd * (qh + kh + vh); // output channels + int Cout = hd * (qh * 3); // output channels for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - // seek to the input position dout[b,t,:] - const float* x = dout + b * T * Cout + t * Cout; + // seek to the input position inp[b,t,:] + float* dx = dinp + b * T * C + t * C; // seek to the output position out[b,t,:] - float* y = dinp + b * T * Cin + t * Cin; + const float* dy = dout + b * T * Cout + t * Cout; // copy all the query vectors, no changes - for (int i = 0; i < hd * qh; i++) { y[i] = x[i]; } - x += hd * qh; // advance input pointer - y += hd * qh; // advance output pointer - // copy key vectors, and replicate them nrep times + for (int i = 0; i < hd * qh; i++) { dx[i] = dy[i]; } + dx += hd * qh; // advance input pointer + dy += hd * qh; // advance output pointer + // gather gradients from the key vectors for (int h = 0; h < kh; h++) { + // init the gradient to 0 + for (int i = 0; i < hd; i++) { dx[i] = 0.0f; } for (int n = 0; n < nrep; n++) { - for (int i = 0; i < hd; i++) { y[i] += x[i]; } - x += hd; // advance input pointer + for (int i = 0; i < hd; i++) { dx[i] += dy[i]; } + dy += hd; // advance output pointer } - y += hd; // advance output pointer + dx += hd; // advance input pointer } - // copy value vectors, and replicate them nrep times + // gather gradients from the value vectors for (int h = 0; h < vh; h++) { + // init the gradient to 0 + for (int i = 0; i < hd; i++) { dx[i] = 0.0f; } for (int n = 0; n < nrep; n++) { - for (int i = 0; i < hd; i++) { y[i] += x[i]; } - x += hd; // advance input pointer + for (int i = 0; i < hd; i++) { dx[i] += dy[i]; } + dy += hd; // advance output pointer } - y += hd; // advance output pointer + dx += hd; // advance input pointer } } } @@ -76,7 +83,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, dinp[dinp_idx] = __ldcs(&dout[dout_idx]); } else if (c == 1) { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += __ldcs(&dout[dout_idx+HD*i]); } @@ -87,7 +94,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, } else { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += __ldcs(&dout[dout_idx+HD*i]); } @@ -141,7 +148,6 @@ int main(int argc, char **argv) { // allocate (and fill) CPU memory float* dinp = (float*)malloc(B * T * Cin * sizeof(float)); - memset(dinp, 0, B * T * Cin * sizeof(float)); float* dout = make_random_float(B * T * Cout * sizeof(float)); // allocate GPU memory @@ -160,7 +166,7 @@ int main(int argc, char **argv) { printf("Using kernel %d\n", kernel_num); // CPU reference calculate - repkv_backward_cpu(dinp, dout, B, T, Cout, hd, qh, kh, vh); + repkv_backward_cpu(dinp, dout, B, T, Cin, hd, qh, kh, vh); // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh index f4c517eaa..a70881402 100644 --- a/llmc/repkv.cuh +++ b/llmc/repkv.cuh @@ -74,7 +74,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, dinp[dinp_idx] = __ldcs(&dout[dout_idx]); } else if (c == 1) { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); } @@ -85,7 +85,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, } else { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); } diff --git a/train_llama3.cu b/train_llama3.cu index 2cc554a0b..65739d1c6 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -897,10 +897,31 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int rmsnorm_backward(dresidual, dl_ln2w, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_rstd, B, T, C, main_stream); matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); + // <--- gradient here matches OK + + #ifdef ENABLE_CUDNN + printf("cuDNN path TODO\n"); exit(0); + float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor + attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + #else + floatX* l_att = acts.att + l * B * NH * T * T; + // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory + floatX* buffer_a = l_atty; + floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need + attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); + #endif + // backward rope (this can be done in-place) + rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); + // backward repkv (use scratchX as gradient buffer here) + repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); + + // <--- here the gradients don't match + // so there is an issue with one of attention, rope, or repkv, or how they are called + // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here // transfer the first 32 elements to CPU and print them - float* output = (float*)dl_btc; + float* output = (float*)dl_bt4c2; floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < 32; i++) { @@ -909,7 +930,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // write to .bin file // move output to cpu // int sz = B*T*qkv_channels; //B*T*C; - int sz = B*T*C; + int sz = B*T*qkv_channels; floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); FILE* f = fopen("out.bin", "wb"); @@ -918,24 +939,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int exit(0); // ------------------------------------------------------------------------ - #ifdef ENABLE_CUDNN - printf("cuDNN path TODO\n"); exit(0); - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); - #else - floatX* l_att = acts.att + l * B * NH * T * T; - // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = l_atty; - floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need - attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); - #endif - // backward rope (this can be done in-place) - rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); - // backward repkv (use scratchX as gradient buffer here) - repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); - - // <--- here the gradients don't match, so there is an issue in between - // backward QKV projection if(model->recompute >= 2) { rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); diff --git a/train_llama3.py b/train_llama3.py index cd1549b42..b654d8d83 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -168,6 +168,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) + + DEBUG_POINT = qkv.detach() + DEBUG_POINT = DEBUG_POINT.requires_grad_(True) + self.DEBUG_POINT = DEBUG_POINT + qkv = DEBUG_POINT + q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 @@ -197,12 +203,6 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): att = F.softmax(scores.float(), dim=-1).type_as(q) y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) - - DEBUG_POINT = y.detach() - DEBUG_POINT = DEBUG_POINT.requires_grad_(True) - self.DEBUG_POINT = DEBUG_POINT - y = DEBUG_POINT - y = self.c_proj(y) return y From e6481b679c4dc3ca7bbd9214a166bde9faea87f2 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 1 Oct 2024 16:36:35 +0000 Subject: [PATCH 36/63] fix bug with qkvr sizing, has to be 3*C. Credit to @ademeure for finding this bug and bringing light to darkness and order to chaos. A true warrior in the fight against entropy. --- train_llama3.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index 65739d1c6..a41946807 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -289,7 +289,7 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T); tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T); tensors[16] = TENSOR_SPEC(data->losses, B * T); - tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * qkv_channels); + tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); // 3*C is correct - this is QKV after replication of KV tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp)))); tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels); tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); @@ -678,7 +678,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels; + floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; @@ -862,7 +862,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; float* l_ln1_rstd = acts.ln1_rstd + l * B * T; - floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels; + floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; floatX* l_atty = acts.atty + l * B * T * C; floatX* l_residual2 = acts.residual2 + l * B * T * C; floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; From 9099a0ae9c91c22ad2858d1acbe9d2e84f6eaa07 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 1 Oct 2024 17:03:29 +0000 Subject: [PATCH 37/63] ok the full backward now shows max abs diff of 3e-3, except for the encoder backward (that's coming next). i think 3e-3 seems ok just inspecting the differences manually. probably this is correct. encoder backward next --- train_llama3.cu | 52 +++++++++++++++++++++++-------------------------- train_llama3.py | 13 ++++++------- 2 files changed, 30 insertions(+), 35 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index a41946807..327802c1d 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -914,31 +914,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); // backward repkv (use scratchX as gradient buffer here) repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); - - // <--- here the gradients don't match - // so there is an issue with one of attention, rope, or repkv, or how they are called - - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - float* output = (float*)dl_bt4c2; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("q[%d] = %.8f\n", i, (float) cpu[i]); - } - // write to .bin file - // move output to cpu - // int sz = B*T*qkv_channels; //B*T*C; - int sz = B*T*qkv_channels; - floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); - FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), sz, f); - fclose(f); - exit(0); - // ------------------------------------------------------------------------ - // backward QKV projection if(model->recompute >= 2) { rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); @@ -958,15 +933,36 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int }; const size_t nelem[] = { C, C, - 3 * C * C, 3 * C, + qkv_channels * C, qkv_channels, C * C, C, C, C, - 4 * C * C, 4 * C, - C * 4 * C, C + ffn_channels * C, ffn_channels, + C * ffn_channels_post_gelu, C }; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } } + + // ------------------------------------------------------------------------ + // DEBUGGING: we only work until this point right now, so exit here + // transfer the first 32 elements to CPU and print them + float* output = (float*)dresidual; + floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); + for (int i = 0; i < 32; i++) { + printf("q[%d] = %.8f\n", i, (float) cpu[i]); + } + // write to .bin file + // move output to cpu + int sz = B*T*C; + floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); + cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); + FILE* f = fopen("out.bin", "wb"); + fwrite(cpu_output, sizeof(floatX), sz, f); + fclose(f); + exit(0); + // ------------------------------------------------------------------------ + encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info, dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); diff --git a/train_llama3.py b/train_llama3.py index b654d8d83..6cf9c6f69 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -168,12 +168,6 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) - - DEBUG_POINT = qkv.detach() - DEBUG_POINT = DEBUG_POINT.requires_grad_(True) - self.DEBUG_POINT = DEBUG_POINT - qkv = DEBUG_POINT - q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 @@ -305,6 +299,11 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): freqs_cis = self.freqs_cis[start_pos:start_pos+t] mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) + DEBUG_POINT = x.detach() + DEBUG_POINT = DEBUG_POINT.requires_grad_(True) + self.DEBUG_POINT = DEBUG_POINT + x = DEBUG_POINT + for i, block in enumerate(self.transformer.h): x = block(x, freqs_cis, start_pos, mask) x = self.transformer.ln_f(x) @@ -1262,7 +1261,7 @@ def get_lr(it): # --------------------------------------------------------------------- # DEBUGGING: print first 32 elements of x - x = model.transformer.h[-1].attn.DEBUG_POINT.grad + x = model.DEBUG_POINT.grad for i in range(32): print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) # write to .bin file From c746e06f49e3000039da79fe5a553317fee7e986 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 1 Oct 2024 17:19:39 +0000 Subject: [PATCH 38/63] take out debugging stuff. we can now run training loop for both models. they don't match yet --- llmc/encoder.cuh | 13 ++++++++----- train_llama3.cu | 24 ++---------------------- train_llama3.py | 17 ----------------- 3 files changed, 10 insertions(+), 44 deletions(-) diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index fbaf56af1..6ab94e3d4 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -197,11 +197,14 @@ void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu output NVTX_RANGE_FN(); // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte) - const int block_size = 256; - const int N = T * C / x128::size; - const int grid_size = CEIL_DIV(N, block_size); - wpe_backward_kernel<<>>(dwpe, dout, inp, B, T, C, seed); - cudaCheck(cudaGetLastError()); + // GPT-2 has wpe (absolute positional encoding), but Llama 3 does not as it uses RoPE + if (dwpe != NULL) { + const int block_size = 256; + const int N = T * C / x128::size; + const int grid_size = CEIL_DIV(N, block_size); + wpe_backward_kernel<<>>(dwpe, dout, inp, B, T, C, seed); + cudaCheck(cudaGetLastError()); + } // check the GPU scratch buffer is large enough to hold the bucket info and workload indices // todo - this is trivially true given hardcoded scratch buffer size here, is this useful? diff --git a/train_llama3.cu b/train_llama3.cu index 327802c1d..ccb76cf52 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -943,27 +943,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } } - // ------------------------------------------------------------------------ - // DEBUGGING: we only work until this point right now, so exit here - // transfer the first 32 elements to CPU and print them - float* output = (float*)dresidual; - floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); - for (int i = 0; i < 32; i++) { - printf("q[%d] = %.8f\n", i, (float) cpu[i]); - } - // write to .bin file - // move output to cpu - int sz = B*T*C; - floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); - cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); - FILE* f = fopen("out.bin", "wb"); - fwrite(cpu_output, sizeof(floatX), sz, f); - fclose(f); - exit(0); - // ------------------------------------------------------------------------ - - encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info, + encoder_backward(grads.wte, NULL, scratchX, model->workload_indices, model->bucket_info, dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); // Aggregate all gradients that are not part of the transformer blocks @@ -977,7 +957,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); // reduce the gradients for non-transformer block parameters floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb}; - const size_t nelem[] = {Vp * C, T * C, C, C}; + const size_t nelem[] = {Vp * C, Vp * C, C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } diff --git a/train_llama3.py b/train_llama3.py index 6cf9c6f69..9a4ee24b3 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -299,11 +299,6 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): freqs_cis = self.freqs_cis[start_pos:start_pos+t] mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) - DEBUG_POINT = x.detach() - DEBUG_POINT = DEBUG_POINT.requires_grad_(True) - self.DEBUG_POINT = DEBUG_POINT - x = DEBUG_POINT - for i, block in enumerate(self.transformer.h): x = block(x, freqs_cis, start_pos, mask) x = self.transformer.ln_f(x) @@ -1258,18 +1253,6 @@ def get_lr(it): # backward pass if not args.inference_only: loss.backward() - - # --------------------------------------------------------------------- - # DEBUGGING: print first 32 elements of x - x = model.DEBUG_POINT.grad - for i in range(32): - print("q[{}]: {:.8f}".format(i, x.view(-1)[i].item())) - # write to .bin file - with open("ref.bin", "wb") as f: - f.write(x.view(-1).cpu().detach().numpy().tobytes()) - breakpoint() - # --------------------------------------------------------------------- - if ddp: dist.all_reduce(lossf, op=dist.ReduceOp.AVG) lossf = lossf.item() From 2602b46bb3401109c220dc56f60bda53b62a727d Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 1 Oct 2024 18:42:38 +0000 Subject: [PATCH 39/63] BF16 opt state (m/v) with stochastic rounding, seems to work really well (OPTIMIZER_LOW_PRECISION=1) --- Makefile | 5 +++++ llmc/adamw.cuh | 22 ++++++++++++++-------- llmc/cuda_common.h | 6 ++++++ llmc/cuda_utils.cuh | 8 ++++---- train_gpt2.cu | 30 +++++++++++++++--------------- train_llama3.cu | 30 +++++++++++++++--------------- 6 files changed, 59 insertions(+), 42 deletions(-) diff --git a/Makefile b/Makefile index dba457ce8..ba7b3f632 100644 --- a/Makefile +++ b/Makefile @@ -243,6 +243,11 @@ else PFLAGS = -DENABLE_BF16 endif +# Optimizer precision settings, enable to allow BF16 for AdamW m/v state (also affects state file) +ifeq ($(OPTIMIZER_LOW_PRECISION), 1) + PFLAGS += -DOPTIMIZER_LOW_PRECISION +endif + # PHONY means these targets will always be executed .PHONY: all train_llama3cu train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 4453576ee..1986d8287 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -16,22 +16,28 @@ __device__ float lerp(float start, float end, float weight) { } template -__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, +__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, floatOpt* m_memory, floatOpt* v_memory, size_t num_parameters, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, float grad_scale, unsigned int seed) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= num_parameters) { return; } // guard + // random number generation (reuse same rng shifted, since 32 bits is overkill for FP32->BF16) + // note this all gets optimised away by the compiler if everything is FP32 + unsigned int random = Get2dNoiseUint(idx, blockIdx.y, seed); + unsigned int random_m = __funnelshift_l(random, random, 10); // rotate by 10 bits + unsigned int random_v = __funnelshift_l(random, random, 20); // rotate by 20 bits + // get the gradient, m, and v for this parameter float grad = grad_scale * (float)grads_memory[idx]; - float m = m_memory[idx]; - float v = v_memory[idx]; + float m = (float)m_memory[idx]; + float v = (float)v_memory[idx]; // update the first moment (momentum) m = lerp(grad, m, beta1); - m_memory[idx] = m; + stochastic_rounding(m, &m_memory[idx], random_m, false); // update the second moment (RMSprop) v = lerp(grad * grad, v, beta2); - v_memory[idx] = v; + stochastic_rounding(v, &v_memory[idx], random_v, false); m /= beta1_correction; // m_hat v /= beta2_correction; // v_hat // fetch the old value of this parameter as a float, from either source @@ -40,14 +46,14 @@ __device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param)); // update our low precision version of the parameters using stochastic rounding // this will be used in the next forward pass - stochastic_rounding(param, ¶ms_memory[idx], seed); + stochastic_rounding(param, ¶ms_memory[idx], random, false); // write the full, float version of the param into our master copy, if we maintain one // this will be used in the next update if (master_params_memory != NULL) { master_params_memory[idx] = param; } } template -__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, +__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, floatOpt* m_memory, floatOpt* v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, float grad_scale, unsigned int seed) { @@ -72,7 +78,7 @@ __global__ void init_from_master_kernel(Tp* params_memory, float* master_params_ } template -void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, +void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, floatOpt* m_memory, floatOpt* v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, float grad_scale, unsigned int seed, cudaStream_t stream) { // AdamW update diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 6f5bf6564..e81c60a8b 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -91,6 +91,12 @@ typedef __nv_bfloat16 floatX; #define PRECISION_MODE PRECISION_BF16 #endif +#if defined(OPTIMIZER_LOW_PRECISION) +typedef floatX floatOpt; +#else +typedef float floatOpt; +#endif + // ---------------------------------------------------------------------------- // Load and store with streaming cache hints // Older nvcc does not provide __ldcs and __stcs for bfloat16, despite these diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 030ec073e..d10ad0a67 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -266,20 +266,20 @@ __device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY } // stochastic rounding built on top of Squirel Noise above (with seed updated per step via xorshift) -__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) { +__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed, bool noise=true) { // todo - is this stochastic rounding *too good*? can we cut any corners? // makes sure each thread gets a different random number - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed); + unsigned int random = noise ? Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed) : seed; unsigned int threshold = random & 0xFFFF; unsigned int float_bits = __float_as_uint(in); unsigned int rounded_bits = float_bits & 0x0000FFFF; float_bits = (rounded_bits > threshold) ? (float_bits | 0xFFFF) : (float_bits & ~0xFFFF); *out = __float2bfloat16_rn(__uint_as_float(float_bits)); } -__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) { +__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random, bool noise=true) { *out = (float)in; // todo - implement this... } -__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) { +__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random, bool noise=true) { *out = in; // dummy function for when floatX is float (FP32 mode) } diff --git a/train_gpt2.cu b/train_gpt2.cu index 70d8d0c5a..d1af34f21 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -295,8 +295,8 @@ typedef struct { ParameterTensors grads; void* grads_memory; // buffers for the AdamW optimizer - float* m_memory; - float* v_memory; + floatOpt* m_memory; + floatOpt* v_memory; float* master_weights; // is NULL unless fp32 weights is enabled. // the activations of the model, and their sizes ActivationTensors acts; @@ -395,12 +395,12 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { // we will now init the optimizer states and master weights // this is usually a substantial amount of memory allocation right here. size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for - printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); - printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(floatOpt)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(floatOpt)) >> 20); assert(model->m_memory == nullptr); assert(model->v_memory == nullptr); - memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float)); - memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float)); + memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(floatOpt)); + memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(floatOpt)); if (model->use_master_weights == 1) { assert(model->master_weights == nullptr); @@ -1050,8 +1050,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo if(init_state) { model->init_state = false; NvtxRange rng("InitOpt"); - cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(floatOpt))); + cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(floatOpt))); } // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint @@ -1082,8 +1082,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full; ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? local_offset_full : local_offset_partial; - float* m_ptr = model->m_memory + opt_state_offset; - float* v_ptr = model->v_memory + opt_state_offset; + floatOpt* m_ptr = model->m_memory + opt_state_offset; + floatOpt* v_ptr = model->v_memory + opt_state_offset; float* master_ptr = nullptr; if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; } if(init_state && model->master_weights != nullptr ) { @@ -1228,10 +1228,10 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard fwriteCheck(state_header, sizeof(int), 256, state_file); - // write AdamW m, v, and master_weights here (they are all float) + // write AdamW m, v, and master_weights here (they are all float, unless OPTIMIZER_LOW_PRECISION is defined) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } @@ -1276,8 +1276,8 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename model->init_state = false; // we just got the state from file, no need to do first-touch init assert(model->m_memory != nullptr); assert(model->v_memory != nullptr); - file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); + file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { assert(model->master_weights != nullptr); file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); diff --git a/train_llama3.cu b/train_llama3.cu index ccb76cf52..d28a38983 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -334,8 +334,8 @@ typedef struct { ParameterTensors grads; void* grads_memory; // buffers for the AdamW optimizer - float* m_memory; - float* v_memory; + floatOpt* m_memory; + floatOpt* v_memory; float* master_weights; // is NULL unless fp32 weights is enabled. // the activations of the model, and their sizes ActivationTensors acts; @@ -455,12 +455,12 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { // we will now init the optimizer states and master weights // this is usually a substantial amount of memory allocation right here. size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for - printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); - printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(floatOpt)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(floatOpt)) >> 20); assert(model->m_memory == nullptr); assert(model->v_memory == nullptr); - memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(float)); - memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(float)); + memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_num_parameters * sizeof(floatOpt)); + memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_num_parameters * sizeof(floatOpt)); if (model->use_master_weights == 1) { assert(model->master_weights == nullptr); @@ -1047,8 +1047,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo if(init_state) { model->init_state = false; NvtxRange rng("InitOpt"); - cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(floatOpt))); + cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(floatOpt))); } // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint @@ -1079,8 +1079,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full; ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? local_offset_full : local_offset_partial; - float* m_ptr = model->m_memory + opt_state_offset; - float* v_ptr = model->v_memory + opt_state_offset; + floatOpt* m_ptr = model->m_memory + opt_state_offset; + floatOpt* v_ptr = model->v_memory + opt_state_offset; float* master_ptr = nullptr; if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; } if(init_state && model->master_weights != nullptr ) { @@ -1225,10 +1225,10 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard fwriteCheck(state_header, sizeof(int), 256, state_file); - // write AdamW m, v, and master_weights here (they are all float) + // write AdamW m, v, and master_weights here (they are all float, unless OPTIMIZER_LOW_PRECISION is defined) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } @@ -1273,8 +1273,8 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename model->init_state = false; // we just got the state from file, no need to do first-touch init assert(model->m_memory != nullptr); assert(model->v_memory != nullptr); - file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); + file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(floatOpt), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { assert(model->master_weights != nullptr); file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); From 2c5ced6a77fe820a25c58c82a6ad1fe724e86923 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 1 Oct 2024 21:41:17 +0000 Subject: [PATCH 40/63] fix bug due to bf16 adamw mv --- llmc/adamw.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 1986d8287..fa8d7a5bb 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -74,7 +74,8 @@ __global__ void init_from_master_kernel(Tp* params_memory, float* master_params_ if (idx >= num_parameters) { return; } params_memory += blockIdx.y * w_stride; // adjust for layer offset master_params_memory += blockIdx.y * s_stride; - stochastic_rounding(master_params_memory[idx], ¶ms_memory[idx], seed); + unsigned int random = Get2dNoiseUint(idx, blockIdx.y, seed); + stochastic_rounding(master_params_memory[idx], ¶ms_memory[idx], random, false); } template From 3745dac6deac478b77f656a96cb939974310fe05 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 17:23:57 +0200 Subject: [PATCH 41/63] define llama3.2 1B and 3B for export from python (will untie embeddings and lm-head for now) --- train_llama3.py | 74 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/train_llama3.py b/train_llama3.py index 9a4ee24b3..8caaaf9f7 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -242,14 +242,15 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): @dataclass class LlamaConfig: - version: str = "3.1" + version: str + n_layer: int + n_head: int + n_embd: int + ffn_dim_multiplier: float + tied_embeddings: bool block_size: int = 8192 vocab_size: int = 128256 - n_layer: int = 32 - n_head: int = 32 n_kv_head: int = 8 - n_embd: int = 4096 - ffn_dim_multiplier: float = 1.3 multiple_of: int = 1024 norm_eps: float = 1e-5 rope_theta: float = 500000.0 @@ -258,14 +259,47 @@ class LlamaConfig: use_kv: bool = True flash: bool = False # use flashattention? - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if hasattr(self, k): - setattr(self, k, v) + def __post_init__(self): assert self.n_kv_head <= self.n_head assert self.n_head % self.n_kv_head == 0 assert self.n_embd % self.n_head == 0 + +LLama3_8BConfig = LlamaConfig( + version="3.1", + n_layer=32, + n_head=32, + n_embd=4096, + ffn_dim_multiplier=1.3, + tied_embeddings=False +) + +LLama3_3BConfig = LlamaConfig( + version="3.2", + n_layer=28, + n_head=24, + n_embd=3072, + ffn_dim_multiplier=1.0, + tied_embeddings=True +) + +LLama3_1BConfig = LlamaConfig( + version="3.2", + n_layer=16, + n_head=32, + n_embd=2048, + ffn_dim_multiplier=1.4, + tied_embeddings=True +) + + +MODEL_DICT: Dict[str, LlamaConfig] = { + "meta-llama/Meta-Llama-3.1-8B": LLama3_8BConfig, + "meta-llama/Llama-3.2-3B": LLama3_3BConfig, + "meta-llama/Llama-3.2-1B": LLama3_1BConfig, +} + + class LLaMA(nn.Module): def __init__(self, config): @@ -401,8 +435,7 @@ def unpermute(w, n_heads, dim1, dim2): def from_pretrained_llama3_hf(cls, model_id): """Loads pretrained LLaMA model weights from HuggingFace""" from transformers import AutoModelForCausalLM, AutoTokenizer - assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-base model is supported for now" - model_args = LlamaConfig() + model_args = MODEL_DICT[model_id] model = AutoModelForCausalLM.from_pretrained(model_id) checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args) @@ -422,7 +455,7 @@ def from_pretrained_llama3_hf(cls, model_id): @classmethod def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): """Loads pretrained LLaMA model weights from a checkpoint directory""" - model_args = LlamaConfig() + model_args = LLama3_8BConfig ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0] checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True) @@ -942,8 +975,9 @@ def write_state(model, x, y, logits, loss, filename): # this can be used for checking the computation correctness in C header = torch.zeros(256, dtype=torch.int32) header[0] = 20240803 # magic - header[1] = x.size(0) # batch size of the batch, B - header[2] = x.size(1) # temporal extent of the batch, T + header[1] = 2 # version + header[2] = x.size(0) # batch size of the batch, B + header[3] = x.size(1) # temporal extent of the batch, T grads = {name: param.grad.cpu() for name, param in model.named_parameters()} with open(filename, "wb") as file: # header @@ -982,7 +1016,7 @@ def print0(*args, **kwargs): parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") - parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="chose the llama model") + parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model") # token layout for each step of the optimization parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions") parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") @@ -1017,7 +1051,7 @@ def print0(*args, **kwargs): B, T = args.batch_size, args.sequence_length assert 1 <= T <= 8192, "sequence length must be between 1 and 8192" assert args.dtype in {"float32", "float16", "bfloat16"} - assert args.model in {"meta-llama/Meta-Llama-3.1-8B"} # only 8B base model supported for now + assert args.model in MODEL_DICT.keys() # create the logging directory if it does not exist logfile = None @@ -1123,10 +1157,10 @@ def print0(*args, **kwargs): logits, loss = model(x, y) loss.backward() # save model params, in bfloat16 - model_to_size = {"meta-llama/Meta-Llama-3.1-8B": "8B"} - model_size_str = model_to_size[args.model] # e.g. "8B" - write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}.bin"), dtype="float32") - write_model(model, os.path.join(args.output_dir, f"llama3.1_{model_size_str}_bf16.bin"), dtype="bfloat16") + model_size_str = args.model.split("-")[-1] + model_version = MODEL_DICT[args.model].version + write_model(model, os.path.join(args.output_dir, f"llama{model_version}_{model_size_str}.bin"), dtype="float32") + write_model(model, os.path.join(args.output_dir, f"llama{model_version}_{model_size_str}_bf16.bin"), dtype="bfloat16") # save x, y, logits, loss, and parameter gradients, for debugging C # always store these in fp32 to have an accurate reference (?) write_state(model, x, y, logits, loss, os.path.join(args.output_dir, f"llama3_{model_size_str}_debug_state.bin")) From 4d7980c26f3aaebc53c7e3e24287b6251c03df42 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 17:25:57 +0200 Subject: [PATCH 42/63] renaming gpt2 -> llama3 --- Makefile | 5 +- test_llama3.cu | 68 ++++++++++----------- train_llama3.cu | 153 ++++++++++++++++++++++++------------------------ 3 files changed, 114 insertions(+), 112 deletions(-) diff --git a/Makefile b/Makefile index ba7b3f632..1d7271f1e 100644 --- a/Makefile +++ b/Makefile @@ -249,7 +249,7 @@ ifeq ($(OPTIMIZER_LOW_PRECISION), 1) endif # PHONY means these targets will always be executed -.PHONY: all train_llama3cu train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu +.PHONY: all train_llama3cu test_llama3cu train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu # Add targets TARGETS = train_gpt2 test_gpt2 @@ -293,6 +293,9 @@ profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN) train_llama3cu: train_llama3.cu $(NVCC_CUDNN) $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) +test_llama3cu: test_llama3.cu $(NVCC_CUDNN) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) + clean: $(REMOVE_FILES) $(TARGETS) $(REMOVE_BUILD_OBJECT_FILES) diff --git a/test_llama3.cu b/test_llama3.cu index e608ce229..6ab6fae83 100644 --- a/test_llama3.cu +++ b/test_llama3.cu @@ -1,5 +1,5 @@ #define TESTING -#include "train_gpt2.cu" +#include "train_llama3.cu" // poor man's tensor checker int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) { @@ -48,7 +48,7 @@ int check_tensor(float *a, float *b, int n, const char* label, float threshold=1 // the same tensors as in the train file, but in float, which are used as reference typedef struct { float* wte; // (Vp, C) - float* wpe; // (maxT, C) + float* wlmhead; // (Vp, C) float* ln1w; // (L, C) float* ln1b; // (L, C) float* qkvw; // (L, 3*C, C) @@ -76,7 +76,7 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size // everything is float so number of bytes to allocate is a simple multiplication float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); float** ptrs[] = { - ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->wte, ¶ms->wlmhead, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb }; @@ -100,16 +100,16 @@ int main(int argc, char *argv[]) { // set the right paths #if defined(ENABLE_BF16) - const char* load_filename = "gpt2_124M_bf16.bin"; + const char* load_filename = "llama3.2_1B_bf16.bin"; #else - const char* load_filename = "gpt2_124M.bin"; + const char* load_filename = "llama3.2_1B.bin"; #endif // build the GPT-2 model from a checkpoint - GPT2 model; - gpt2_init_common(&model); + LLama3 model; + llama3_init_common(&model); - gpt2_build_from_checkpoint(&model, load_filename); + llama3_build_from_checkpoint(&model, load_filename); size_t V = model.config.vocab_size; size_t Vp = model.config.padded_vocab_size; size_t maxT = model.config.max_seq_len; @@ -126,13 +126,13 @@ int main(int argc, char *argv[]) { } // load additional information that we will use for debugging and error checking - FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); + FILE *state_file = fopenCheck("llama3_1B_debug_state.bin", "rb"); int state_header[256]; freadCheck(state_header, sizeof(int), 256, state_file); - if (state_header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); } + if (state_header[0] != 20240803) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); } if (state_header[1] != 2) { - fprintf(stderr, "Bad version in state file\n"); - fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + fprintf(stderr, "Bad version in state file: %d\n", state_header[1]); + fprintf(stderr, "---> HINT: try to re-run `python train_llama3.py`\n"); exit(EXIT_FAILURE); } int B = state_header[2]; // batch size, e.g. 4 @@ -168,10 +168,10 @@ int main(int argc, char *argv[]) { // overall OK signal for the test int allok = 1; - gpt2_allocate_state(&model, B, T); + llama3_allocate_state(&model, B, T); // First, do target-free forward pass to validate logits - gpt2_forward(&model, x, B, T); + llama3_forward(&model, x, B, T); // at this point, target should be equal to expected_logits, let's compare // copy logits to CPU so we can compare them floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX)); @@ -220,8 +220,8 @@ int main(int argc, char *argv[]) { for (int step = 0; step < 10; step++) { struct timespec start, end; clock_gettime(CLOCK_MONOTONIC, &start); - gpt2_forward(&model, x, B, T); - gpt2_backward_and_reduce(&model, x, y, 1, 0); + llama3_forward(&model, x, B, T); + llama3_backward_and_reduce(&model, x, y, 1, 0); clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; @@ -275,7 +275,7 @@ int main(int argc, char *argv[]) { #endif allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", grad_thresholds[0]); - allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", grad_thresholds[1]); + allok = allok & check_tensor(tensors1[1], tensors2[1], V * C, "wlmhead", grad_thresholds[1]); allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", grad_thresholds[2]); allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", grad_thresholds[3]); allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", grad_thresholds[4]); @@ -292,9 +292,9 @@ int main(int argc, char *argv[]) { allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]); } - float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_norm = llama3_calculate_grad_norm(&model, &multi_gpu_config); float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; - gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); + llama3_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); // print the timing information at the end printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); @@ -329,32 +329,32 @@ int main(int argc, char *argv[]) { } // Finally, let's check determinism - gpt2_write_to_checkpoint(&model, "test_gpt2cu_model.ckpt"); + llama3_write_to_checkpoint(&model, "test_llama3cu_model.ckpt"); DataLoader loader; dataloader_init(&loader, "dev/data/tinyshakespeare/tiny_shakespeare_val.bin", B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes, 1); - save_state("test_gpt2cu_state.ckpt", 10, &model, &loader); + save_state("test_llama3cu_state.ckpt", 10, &model, &loader); int tokens[10]; for (int step = 0; step < 10; step++) { dataloader_next_batch(&loader); - gpt2_forward(&model, loader.inputs, B, T); - gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); - gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); + llama3_forward(&model, loader.inputs, B, T); + llama3_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); + llama3_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); losses[step] = model.mean_loss; tokens[step] = loader.inputs[0]; } // reload - gpt2_free(&model); - gpt2_build_from_checkpoint(&model, "test_gpt2cu_model.ckpt"); + llama3_free(&model); + llama3_build_from_checkpoint(&model, "test_llama3cu_model.ckpt"); int ld_step; - gpt2_allocate_state(&model, B, T); - load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt"); + llama3_allocate_state(&model, B, T); + load_state(&ld_step, &model, &loader, "test_llama3cu_state.ckpt"); for (int step = 0; step < 10; step++) { dataloader_next_batch(&loader); - gpt2_forward(&model, loader.inputs, B, T); - gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); - gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); + llama3_forward(&model, loader.inputs, B, T); + llama3_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0); + llama3_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config); if(loader.inputs[0] != tokens[step]) { printf("Nondeterminism! Token mismatch at step %d: %d vs %d\n", step, tokens[step], loader.inputs[0]); @@ -375,12 +375,12 @@ int main(int argc, char *argv[]) { printf("overall okay: %d\n", allok); // delete intermediate test files - remove("test_gpt2cu_model.ckpt"); - remove("test_gpt2cu_state.ckpt"); + remove("test_llama3cu_model.ckpt"); + remove("test_llama3cu_state.ckpt"); // free everything dataloader_free(&loader); - gpt2_free(&model); + llama3_free(&model); common_free(model); free(x); free(y); diff --git a/train_llama3.cu b/train_llama3.cu index d28a38983..d3369d0db 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -90,28 +90,28 @@ cudaStream_t main_stream; constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // ---------------------------------------------------------------------------- -// GPT-2 model definition +// LLama-3 model definition typedef struct { int max_seq_len; // max sequence length, e.g. 1024 - int vocab_size; // vocab size, e.g. 50257 - int padded_vocab_size; // padded to e.g. %128==0, 50304 + int vocab_size; // vocab size, e.g. 128256 + int padded_vocab_size; // padded to e.g. %128==0, 128256 int num_layers; // number of layers, e.g. 12 - int num_heads; // number of query heads in attention, e.g. 12 - int num_kv_heads; // number of key and value heads in attention, e.g. 4 (<-- new in Llama 3) - int channels; // number of channels, e.g. 768 + int num_heads; // number of query heads in attention, e.g. 32 + int num_kv_heads; // number of key and value heads in attention, e.g. 8 (<-- new in Llama 3) + int channels; // number of channels, e.g. 2048 int multiple_of; // used in feedforward layer sizing, e.g. 1024 (<-- new in Llama 3) int use_scaled_rope; // whether to use scaled rope float ffn_dim_multiplier; // multiplier used in feedforward layer, e.g. 1.3 (<-- new in Llama 3) float norm_eps; // epsilon used in layernorm, e.g. 1e-5 float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3) -} GPT2Config; +} LLama3Config; // the parameters of the model constexpr const int NUM_PARAMETER_TENSORS = 16; typedef struct { floatX* wte; // (V, C) - floatX* wpe; // (V, C) + floatX* wlmhead; // (V, C) floatX* ln1w; // (L, C) floatX* ln1b; // (L, C) floatX* qkvw; // (L, 3*C, C) @@ -129,7 +129,7 @@ typedef struct { } ParameterTensors; static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); -void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { +void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, LLama3Config config) { // see train_llama3.py write_tensors() function for detailed docs of some of the trickery here // trick 1: all biases are still present but set to zero // trick 2: the SwiGLU weights are "packed" into one, concatenated @@ -185,7 +185,7 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters_bytes)); // assign all the tensors their place in the array floatX** ptrs[] = { - ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->wte, ¶ms->wlmhead, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb }; @@ -247,7 +247,7 @@ struct TensorSpec { #define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}; -void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) { +void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, LLama3Config config, int recompute) { const size_t Vp = config.padded_vocab_size; const size_t L = config.num_layers; const size_t NH = config.num_heads; @@ -322,7 +322,7 @@ void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS] } typedef struct { - GPT2Config config; + LLama3Config config; // the weights of the model, and their sizes ParameterTensors params; size_t param_elements[NUM_PARAMETER_TENSORS]; @@ -350,7 +350,7 @@ typedef struct { float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. - unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights + unsigned long long rng_state_last_update; // RNG before last llama3_update() to re-round identically from master weights int use_master_weights; // keep master weights copy in float for optim update? 0|1 bool init_state; // set to true if master weights need to be initialized int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) @@ -359,9 +359,9 @@ typedef struct { int* workload_indices; // encoder_backward, B*T*num_c_groups (int) int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case floatX* freqs_cis; // (T, hd) for RoPE -} GPT2; +} LLama3; -void gpt2_init_common(GPT2 *model) { +void llama3_init_common(LLama3 *model) { // common inits outside of the model weights // memory lazily initialized in forward() model->acts_memory = NULL; @@ -391,7 +391,7 @@ void gpt2_init_common(GPT2 *model) { model->freqs_cis = NULL; } -void gpt2_allocate_weights(GPT2 *model) { +void llama3_allocate_weights(LLama3 *model) { // fill in all the parameter tensor dimensions and types fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config); model->num_parameters = 0; @@ -415,7 +415,7 @@ void gpt2_allocate_weights(GPT2 *model) { model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); } -void gpt2_allocate_state(GPT2 *model, int B, int T) { +void llama3_allocate_state(LLama3 *model, int B, int T) { printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024))); assert(model->grads_memory == nullptr); model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof); @@ -487,7 +487,7 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); } -void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { +void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) { // write the model to a checkpoint file printf0("Writing model to %s\n", checkpoint_path); FILE *model_file = fopenCheck(checkpoint_path, "wb"); @@ -511,7 +511,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { fcloseCheck(model_file); } -void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool weight_init=true) { +void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bool weight_init=true) { // If weight_init is true, we will load the weights from this checkpoint .bin file // We sometimes want this to be false, if we are going to initialize these weights from // the master weights that are instead stored in the state .bin file. @@ -539,7 +539,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // 3 = fp32, padded vocab // 5 = bf16, padded vocab, layernorms also in bf16 fprintf(stderr, "Bad version in model file\n"); - fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_llama3.py`\n"); exit(EXIT_FAILURE); } @@ -552,7 +552,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w } if (PRECISION_MODE == PRECISION_FP32 && version != 3) { fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path); - fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n"); + fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_llama3cu PRECISION=FP32`\n"); fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n"); exit(EXIT_FAILURE); } @@ -597,7 +597,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // ------------------------------------------------------------------------ // allocate memory for the model parameters - gpt2_allocate_weights(model); + llama3_allocate_weights(model); // read in the parameters if weight_init is true if (weight_init) { @@ -612,7 +612,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // propagate inputs through the network to produce logits. // right now, this function is fully synchronous with the host -void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { +void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); // we must be careful and use size_t instead of int, otherwise // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16. @@ -657,7 +657,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { ParameterTensors params = model->params; // for brevity ActivationTensors acts = model->acts; encoder_forward(acts.encoded, model->inputs, params.wte, NULL, B, T, C, main_stream); // encoding goes into residual[0] - // first layernorm isn't fused + // first rmsnorm isn't fused rmsnorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_rstd, acts.encoded, params.ln1w, B, T, C, main_stream); for (int l = 0; l < L; l++) { @@ -686,7 +686,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { floatX* l_fch = acts.fch + l * B * T * ffn_channels; // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size - floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; + floatX* l_fch_swiglu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; floatX* l_residual3 = acts.residual3 + l * B * T * C; floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication @@ -715,8 +715,8 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); fused_residual_rmsnorm_forward5(l_residual2, l_ln2, l_ln2_rstd, residual, scratch, l_ln2w, B*T, C, main_stream); matmul_forward_cublaslt(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, ffn_channels, main_stream); - swiglu_forward(l_fch_gelu, l_fch, B, T, ffn_channels_post_gelu, main_stream); - matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, ffn_channels_post_gelu, C, main_stream); + swiglu_forward(l_fch_swiglu, l_fch, B, T, ffn_channels_post_gelu, main_stream); + matmul_forward_cublaslt(scratch, l_fch_swiglu, l_fcprojw, l_fcprojb, B, T, ffn_channels_post_gelu, C, main_stream); // OK, fusion across blocks. if(l+1 != L) { @@ -729,7 +729,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { } } - matmul_forward_cublaslt(acts.output, acts.lnf, params.wpe, NULL, B, T, C, Vp, main_stream); + matmul_forward_cublaslt(acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream); cudaCheck(cudaDeviceSynchronize()); } @@ -737,10 +737,10 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // Forwards both the model and the loss and is used for validation splits and evals. // In particular it populates cpu_losses with loss at each token. // Some of the evals (e.g. HellaSwag) require the per-token losses, which are produced here. -float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) { +float llama3_validate(LLama3 *model, const int* inputs, const int* targets, size_t B, size_t T) { assert(targets != NULL); // forward the model itself - gpt2_forward(model, inputs, B, T); + llama3_forward(model, inputs, B, T); // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; @@ -764,7 +764,7 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B return mean_loss; } -void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { +void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { if(model->grads_memory == nullptr) { fprintf(stderr, "Need to allocate gradients before backward"); exit(EXIT_FAILURE); @@ -825,7 +825,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(model->acts.scratch_bt4c, grads.wpe, NULL, acts.output, acts.lnf, params.wpe, NULL, B, T, C, Vp, main_stream); + matmul_backward(model->acts.scratch_bt4c, grads.wlmhead, NULL, acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream); // backward the final layernorm floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream); @@ -868,7 +868,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; float* l_ln2_rstd = acts.ln2_rstd + l * B * T; floatX* l_fch_pre_gelu = acts.fch + l * B * T * ffn_channels; - floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; + floatX* l_fch_swiglu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; // get the pointers of the gradients of the activations for this layer // notice that there is no l *, because we just have a single copy, and keep // re-using this memory in every Transformer block as we calculate backward pass @@ -879,12 +879,11 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // start the backward pass for this layer if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu. in this case, - // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here - // gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream); - swiglu_forward(l_fch_gelu, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); + // l_fch_swiglu is just a buffer, so re-compute the gelu from l_fch here + swiglu_forward(l_fch_swiglu, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); } // backward the 2nd matmul of MLP - matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream); + matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_swiglu, l_fcprojw, nullptr, B, T, ffn_channels_post_gelu, C, main_stream); // backward the swiglu here, use scratchX to hold the grad because SwiGLU can't be inplace swiglu_backward(dl_bt4c2, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); // backward the 1st matmul of MLP @@ -956,7 +955,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int #endif cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); // reduce the gradients for non-transformer block parameters - floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb}; + floatX* const pointers[] = {grads.wte, grads.wlmhead, grads.lnfw, grads.lnfb}; const size_t nelem[] = {Vp * C, Vp * C, C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } @@ -969,9 +968,9 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } } -// Gets the offset of a specific tensor for a specific layer in the GPT2 model +// Gets the offset of a specific tensor for a specific layer in the LLama3 model // layer_id is ignored for weights that are not part of a transformer block -ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) { +ShardInfo llama3_get_tensor_at_layer(const LLama3 *model, int layer_id, int param_tensor_id) { // first offset our way to the parameter tensor start ptrdiff_t offset = 0; for (int i = 0; i < param_tensor_id; i++) { @@ -986,7 +985,7 @@ ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te return {offset, size}; } -float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { +float llama3_calculate_grad_norm(LLama3 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); floatX* grads_memory = (floatX*)model->grads_memory; @@ -1001,7 +1000,7 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { // grads_memory only contains the averaged gradients at the local shards, // so we only calculate the grad norm at the grads_memory belonging to the local shards for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); + ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i); ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); ptrdiff_t offset = tensor.offset + shard.offset; bool is_first_pass = (i == 0); @@ -1029,8 +1028,8 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { return grad_norm_cpu; } -void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, - MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { +void llama3_update(LLama3 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, + MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { // update the model parameters using the AdamW optimizer // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs // so we may not be responsible for the entire parameter tensor @@ -1065,7 +1064,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo num_layers = 1; } - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); + ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i); ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); ptrdiff_t local_offset_full = tensor.offset + shard.offset; ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes; @@ -1120,7 +1119,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo cudaCheck(cudaDeviceSynchronize()); } -float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { +float llama3_estimate_mfu(LLama3 *model, int num_tokens, float dt) { /* Estimate model flops utilization (MFU) ref: Section 2.1 of https://arxiv.org/pdf/2001.08361 @@ -1149,7 +1148,7 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { return mfu; } -void gpt2_free(GPT2 *model) { +void llama3_free(LLama3 *model) { cudaFreeCheck(&model->params_memory); cudaFreeCheck(&model->grads_memory); cudaFreeCheck(&model->m_memory); @@ -1193,7 +1192,7 @@ void common_start(bool override_enable_tf32 = true, bool print_device_info = tru #endif } -void common_free(GPT2 &model) { +void common_free(LLama3 &model) { cudaCheck(cudaStreamDestroy(main_stream)); cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasLtDestroy(cublaslt_handle)); @@ -1203,7 +1202,7 @@ void common_free(GPT2 &model) { } -void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) { +void save_state(const char* filename, int step, LLama3* model, DataLoader* loader) { printf("Writing state to %s\n", filename); FILE *state_file = fopenCheck(filename, "wb"); int state_header[256]; @@ -1219,7 +1218,7 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) state_header[10] = step; // step of the optimization // model rng state, start at 20 to leave some padding *((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state - *((unsigned long long*)&state_header[22]) = model->rng_state_last_update; // last gpt2_update + *((unsigned long long*)&state_header[22]) = model->rng_state_last_update; // last llama3_update // dataloader state, start at 30 to leave some padding *((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard @@ -1244,7 +1243,7 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) fcloseCheck(state_file); } -void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) { +void load_state(int* step, LLama3* model, DataLoader* loader, const char* filename) { FILE *state_file = fopenCheck(filename, "rb"); int state_header[256]; freadCheck(state_header, sizeof(int), 256, state_file); @@ -1256,7 +1255,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename int should_shuffle = state_header[5]; // shuffle state of the dataloader *step = state_header[10]; // step of the optimization model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state - model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update + model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last llama3_update size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard @@ -1280,7 +1279,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); // restore weights from the master weights using the RNG state before last weight update model->rng_state = model->rng_state_last_update; - gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); + llama3_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this } @@ -1310,14 +1309,14 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename fcloseCheck(state_file); } -void write_checkpoint(const char* output_log_dir, int step, GPT2* model, DataLoader* train_loader, MultiGpuConfig* multi_gpu_config) { +void write_checkpoint(const char* output_log_dir, int step, LLama3* model, DataLoader* train_loader, MultiGpuConfig* multi_gpu_config) { // a checkpoint contains: model weights, optimizer/dataloader state, and a DONE file printf0("Writing checkpoint at step %d\n", step); int rank = multi_gpu_config->process_rank; // only rank 0 writes the model file because it is the same across all ranks if (rank == 0) { snprintf(filename_buffer, sizeof(filename_buffer), "%s/model_%08d.bin", output_log_dir, step); - gpt2_write_to_checkpoint(model, filename_buffer); + llama3_write_to_checkpoint(model, filename_buffer); } // all ranks write their state file snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, step, rank); @@ -1348,7 +1347,7 @@ void delete_checkpoint(const char* output_log_dir, int step, MultiGpuConfig* mul } #ifndef TESTING -// if we are TESTING (see test_gpt2.cu), we'll skip everything below this point +// if we are TESTING (see test_llama3.cu), we'll skip everything below this point // ---------------------------------------------------------------------------- // training resumption logic, very useful when jobs crash once in a while @@ -1360,12 +1359,12 @@ void delete_checkpoint(const char* output_log_dir, int step, MultiGpuConfig* mul // (all single letters have been claimed now) void error_usage() { - fprintf(stderr, "Usage: ./train_gpt2cu [options]\n"); + fprintf(stderr, "Usage: ./train_llama3cu [options]\n"); fprintf(stderr, "Options:\n"); // file system input / output fprintf(stderr, " -i train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n"); fprintf(stderr, " -j val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n"); - fprintf(stderr, " -e input .bin filename or descriptor, see code comments as docs. (default = gpt2_124M_bf16.bin)\n"); + fprintf(stderr, " -e input .bin filename or descriptor, see code comments as docs. (default = llama3.2_1B_bf16.bin)\n"); fprintf(stderr, " -o output log dir (default = NULL, no logging)\n"); fprintf(stderr, " -lg log gpu info every x steps (default = -1; disabled)\n"); fprintf(stderr, " -n write optimization checkpoints every how many steps? (default 0, don't)\n"); @@ -1417,7 +1416,7 @@ int main(int argc, char *argv[]) { // read in the (optional) command line arguments const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin"; const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin"; - const char* load_filename = "llama3.1_8B.bin"; // bf16 weights of the Llama 3.1 8B model + const char* load_filename = "llama3.2_1B_bf16.bin"; // bf16 weights of the Llama 3.2 1B model const char* lr_scheduler_type = "cosine"; const char* output_log_dir = NULL; int checkpoint_every = 0; // write checkpoints every how many steps? @@ -1510,12 +1509,12 @@ int main(int argc, char *argv[]) { // calculate sensible default for total batch size as assuming no gradient accumulation if (total_batch_size == -1) { total_batch_size = tokens_per_fwdbwd; } // in the future, we might want to set gelu fusion to 2 for SM90+ and 0 for other GPUs - if (gelu_fusion == -1) { gelu_fusion = 0; } // (deviceProp.major >= 9) ? 2 : 0; } // in gpt2_init_common for test_gpt2cu... + if (gelu_fusion == -1) { gelu_fusion = 0; } // (deviceProp.major >= 9) ? 2 : 0; } // in llama3_init_common for test_llama3cu... // calculate the number of gradient accumulation steps from the desired total batch size assert(total_batch_size % tokens_per_fwdbwd == 0); int grad_accum_steps = total_batch_size / tokens_per_fwdbwd; // if we're only overfitting a single batch for debugging, let's overfit the first batch - // from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same) + // from val instead of train split, because val is smaller and faster. (train_llama3.py does the same) if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; } printf0("+-----------------------+----------------------------------------------------+\n"); printf0("| Parameter | Value |\n"); @@ -1566,16 +1565,16 @@ int main(int argc, char *argv[]) { } // build the GPT-2 model - GPT2 model; - gpt2_init_common(&model); + LLama3 model; + llama3_init_common(&model); if (resuming == 1) { // if `-y 1` was set, then we are resuming from the latest checkpoint // if we are using master weights, we'll init them later inside load_state() bool weight_init = !use_master_weights; - gpt2_build_from_checkpoint(&model, filename_buffer, weight_init); + llama3_build_from_checkpoint(&model, filename_buffer, weight_init); } else if (ends_with_bin(load_filename)) { // otherwise, if this is a .bin file, we assume it's a model, let's init from it - gpt2_build_from_checkpoint(&model, load_filename); + llama3_build_from_checkpoint(&model, load_filename); } else { // For Llama 3.1 we currently demand a .bin file to load the model from, and // initializing from scratch is currently not supported (but can be added later) @@ -1644,7 +1643,7 @@ int main(int argc, char *argv[]) { printf0("HellaSwag eval not found at %s, skipping its evaluation\n", hellaswag_path); printf0("You can run `python dev/data/hellaswag.py` to export and use it with `-h 1`.\n"); } - // more prints related to allocations from gpt2_build_from_checkpoint down here to not mess up our table above + // more prints related to allocations from llama3_build_from_checkpoint down here to not mess up our table above printf0("num_parameters: %zu => bytes: %zu\n", model.num_parameters, model.num_parameters_bytes); printf0("allocated %d MiB for model parameters\n", (int)round(model.num_parameters_bytes / (1024 * 1024))); // few more prints for gradient accumulation math up above @@ -1659,7 +1658,7 @@ int main(int argc, char *argv[]) { // set up the Tokenizer Tokenizer tokenizer; - // tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); // TODO: port tokenizer later from GPT2 -> Llama 3 + // tokenizer_init(&tokenizer, "llama3_tokenizer.bin"); // TODO: port tokenizer later from GPT2 -> Llama 3 // set up learning rate scheduler LearningRateScheduler lr_scheduler; @@ -1673,7 +1672,7 @@ int main(int argc, char *argv[]) { // if we found a checkpoint to resume from, load the optimization state int step = 0; - gpt2_allocate_state(&model, B, T); + llama3_allocate_state(&model, B, T); if (resuming == 1) { snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); load_state(&step, &model, &train_loader, filename_buffer); @@ -1705,7 +1704,7 @@ int main(int argc, char *argv[]) { dataloader_reset(&val_loader); for (int i = 0; i < val_num_batches; i++) { dataloader_next_batch(&val_loader); - val_loss += gpt2_validate(&model, val_loader.inputs, val_loader.targets, B, T); + val_loss += llama3_validate(&model, val_loader.inputs, val_loader.targets, B, T); } val_loss /= val_num_batches; val_loss = multi_gpu_cpu_float_sum(val_loss, &multi_gpu_config) / multi_gpu_config.num_processes; @@ -1722,7 +1721,7 @@ int main(int argc, char *argv[]) { for (int i = 0; i < eval_loader.num_batches; i++) { if (i % 10 == 0) { printf("evaluating HellaSwag: %d/%d\r", i, eval_loader.num_batches); } evalloader_next_batch(&eval_loader); - gpt2_validate(&model, eval_loader.inputs, eval_loader.targets, B, T); + llama3_validate(&model, eval_loader.inputs, eval_loader.targets, B, T); int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses); eval_acc_norm += (float)correct; } @@ -1753,7 +1752,7 @@ int main(int argc, char *argv[]) { // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical // (but even if it wasn't fully identical that's probably not the end of the world) // note this is still somewhat wasteful because we don't have a KV cache! - gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); + llama3_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); // get the V-dimensional vector probs[0, t-1, :] floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) @@ -1815,15 +1814,15 @@ int main(int argc, char *argv[]) { // fetch the next data batch dataloader_next_batch(&train_loader); // forward pass. note that we pass in grad_accum_steps, which scales down the loss - gpt2_forward(&model, train_loader.inputs, B, T); + llama3_forward(&model, train_loader.inputs, B, T); // backward pass. all model params accumulate gradients with += inside this inner loop - gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step); + llama3_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step); } float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss)); // loss z-score // fetch the next learning rate float step_learning_rate = get_learning_rate(&lr_scheduler, step); // calculate the gradient norm and how much we wish to scale the gradient - float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_norm = llama3_calculate_grad_norm(&model, &multi_gpu_config); float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score // update the model parameters if (isfinite(zloss) && skip_update_lossz != 0.0f && zloss > skip_update_lossz) { @@ -1834,7 +1833,7 @@ int main(int argc, char *argv[]) { // clip the gradient norm to a maximum value float grad_clip = 1.0f; float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f; - gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); + llama3_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); } cudaCheck(cudaEventRecord(end)); cudaCheck(cudaEventSynchronize(end)); // wait for the end event to finish to get correct timings @@ -1853,7 +1852,7 @@ int main(int argc, char *argv[]) { ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second; bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step)); } - float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f); + float mfu = llama3_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f); printf0("step %4d/%d | loss %7.6f (%+.2fz)| norm %6.4f (%+.2fz)| lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", step + 1, train_num_batches, model.mean_loss, zloss, grad_norm, zgrad, step_learning_rate, time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second); @@ -1882,7 +1881,7 @@ int main(int argc, char *argv[]) { free(cpu_logits); free(gen_tokens); multi_gpu_config_free(&multi_gpu_config); - gpt2_free(&model); + llama3_free(&model); common_free(model); return 0; } From 090341e8dcf1e62593f45448924649aba8e948d1 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 17:55:15 +0200 Subject: [PATCH 43/63] enable llama3 CI --- .github/workflows/ci.yml | 15 +++++++++++ .github/workflows/ci_gpu.yml | 50 +++++++++++++++++++++++++++++++++++- train_llama3.py | 6 +++-- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 27ebad62e..4840d30e4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -219,3 +219,18 @@ jobs: - name: Build project run: make -j4 -C dev/cuda + + build-llama3: + runs-on: ubuntu-latest + container: + image: nvidia/cuda:12.4.1-devel-ubuntu22.04 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Build FP32 + run: PRECISION=FP32 make test_llama3cu train_llama3cu + + - name: Build BF16 + run: PRECISION=BF16 make test_llama3cu train_llama3cu diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index f4a9dfb4b..782ba73d2 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -9,9 +9,10 @@ on: pull_request: branches: - master + - llama3 jobs: - build-and-test-gpu: + build-and-test-gpt2: runs-on: ubicloud-gpu-standard-1-latest steps: @@ -117,6 +118,53 @@ jobs: - name: Execute testing program fp32 with cuDNN run: ./test_gpt2fp32cu + build-and-test-llama3: + runs-on: ubicloud-gpu-standard-1-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install OpenMP + run: sudo apt-get update && sudo apt-get install -y libomp-dev + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Run preprocessing + run: python dev/data/tinyshakespeare.py --model_desc llama-3 + + - name: Train model + # modle is too big to fit on our runner GPUs. Thus, we run it on cpu :( + run: python train_llama3.py --write_tensors 1 --dtype float32 --device cpu + + - name: Build FP32 precision + run: PRECISION=FP32 make test_llama3cu + + - name: Run default + run: ./test_llama3cu + + - name: Run no recompute GeLU + run: ./test_llama3cu -r 0 + + - name: Run recompute LN + run: ./test_llama3cu -r 2 + + - name: Build BF16 precision + run: PRECISION=BF16 make train_llama3cu test_llama3cu + + - name: Run default + run: ./test_llama3cu + + - name: Run no recompute GeLU + run: ./test_llama3cu -r 0 + + - name: Run no master weights + run: ./test_llama3cu -w 0 + + - name: Run recompute LN + run: ./test_llama3cu -r 2 + unit-tests-gpu: runs-on: ubicloud-gpu-standard-1-latest diff --git a/train_llama3.py b/train_llama3.py index 8caaaf9f7..e99e5f84d 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -317,12 +317,13 @@ def __init__(self, config): self.init_rng = torch.Generator() self.init_rng.manual_seed(42) - self.freqs_cis = precompute_freqs_cis( + freqs_cis = precompute_freqs_cis( config.n_embd // config.n_head, config.block_size * 2, config.rope_theta, config.use_scaled_rope, ) + self.register_buffer('freqs_cis', freqs_cis, persistent=False) def forward(self, idx, targets=None, return_logits=True, start_pos=0): _, t = idx.size() @@ -1095,7 +1096,7 @@ def print0(*args, **kwargs): elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): device = "mps" device_type = 'cuda' if 'cuda' in device else 'cpu' - assert device_type in {'cuda'}, "GPU required to run LLaMA 3" # we need to load LLaMA as bf16 on CUDA + #assert device_type in {'cuda'}, "GPU required to run LLaMA 3" # we need to load LLaMA as bf16 on CUDA print(f"using device: {device}") # calculate gradient accumulation from the desired total batch size and the current run configuration @@ -1131,6 +1132,7 @@ def print0(*args, **kwargs): if args.dtype == "float32": model = model.to(torch.float32) + model = model.to(device) model.train() if args.compile: if hasattr(config, "coordinate_descent_tuning"): From a94471c627ce44334eee69eaf6a469f46373173b Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 14 Apr 2025 11:15:26 +0200 Subject: [PATCH 44/63] use optimizer offloading when running in CI --- .github/workflows/ci_gpu.yml | 3 +-- requirements.txt | 1 + train_llama3.py | 13 +++++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index 782ba73d2..412c38a5b 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -135,8 +135,7 @@ jobs: run: python dev/data/tinyshakespeare.py --model_desc llama-3 - name: Train model - # modle is too big to fit on our runner GPUs. Thus, we run it on cpu :( - run: python train_llama3.py --write_tensors 1 --dtype float32 --device cpu + run: python train_llama3.py --write_tensors 1 --dtype float32 - name: Build FP32 precision run: PRECISION=FP32 make test_llama3cu diff --git a/requirements.txt b/requirements.txt index ea4bc768d..5e00477e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ tiktoken transformers datasets requests +torchao \ No newline at end of file diff --git a/train_llama3.py b/train_llama3.py index e99e5f84d..2fbee5d01 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -472,7 +472,7 @@ def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): model.tokenizer = tokenizer return model - def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage): + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage, offload): # start with all of the candidate parameters param_dict = {pn: p for pn, p in self.named_parameters()} # filter out those that do not require grad @@ -494,10 +494,14 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, use_fused = fused_available and device_type == 'cuda' print0(f"using fused AdamW: {use_fused}") if zero_stage == 1: + assert not offload print0("using ZeroRedundancyOptimizer") optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW, lr=learning_rate, betas=betas, fused=use_fused) optimizer.add_param_group(optim_groups[1]) + elif offload: + from torchao.prototype.low_bit_optim import CPUOffloadOptimizer + optimizer = CPUOffloadOptimizer(optim_groups, torch.optim.AdamW, lr=learning_rate, betas=betas, fused=use_fused) else: print0("using regular AdamW") optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) @@ -1178,9 +1182,14 @@ def print0(*args, **kwargs): raw_model = model.module if ddp else model # always contains the "raw" unwrapped model # init the optimizer + offload = False + gpu_memory_mib = torch.cuda.get_device_properties(0).total_memory // 1024 // 1024 + if not ddp and gpu_memory_mib < 24_000: + print(f"GPU has only {gpu_memory_mib} MiB of memory, offloading optimizer to CPU") + offload = True optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, learning_rate=args.learning_rate, betas=(0.9, 0.95), - device_type=device, zero_stage=zero_stage) + device_type=device, zero_stage=zero_stage, offload=offload) # learning rate decay scheduler (cosine with warmup) def get_lr(it): From 4983c462a7e38a998de4cc5b6cacb8ea5419b63e Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 18:32:06 +0200 Subject: [PATCH 45/63] fix: fully ignore biases --- llmc/zero.cuh | 1 + train_llama3.cu | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/llmc/zero.cuh b/llmc/zero.cuh index e6c5b6e7c..939455063 100644 --- a/llmc/zero.cuh +++ b/llmc/zero.cuh @@ -529,6 +529,7 @@ void multi_gpu_async_reduce_gradient( cudaCheck(cudaStreamWaitEvent(config->nccl_stream, config->compute_nccl_sync)); ncclCheck(ncclGroupStart()); // NCCL group: aggregate all pointers in a single NCCL GPU kernel. for (int i = 0; i < N; ++i) { + if(pointers[i] == nullptr) continue; if(config->zero_stage == 0) { ncclCheck(ncclAllReduce( pointers[i], pointers[i], diff --git a/train_llama3.cu b/train_llama3.cu index d3369d0db..c7b108f9d 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -105,6 +105,7 @@ typedef struct { float ffn_dim_multiplier; // multiplier used in feedforward layer, e.g. 1.3 (<-- new in Llama 3) float norm_eps; // epsilon used in layernorm, e.g. 1e-5 float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3) + bool use_biases; // we always allocate memory for biases; to match llama3 they are not used } LLama3Config; // the parameters of the model @@ -569,6 +570,7 @@ void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bo model->config.channels = header_int[7]; model->config.multiple_of = header_int[8]; model->config.use_scaled_rope = header_int[9]; + model->config.use_biases = false; int major_version = header_int[10]; // currently unused, e.g. 3 int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1) // now the float section @@ -667,14 +669,14 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) { // get the pointers of the weights for this layer floatX* l_qkvw = params.qkvw + l * qkv_channels * C; - floatX* l_qkvb = params.qkvb + l * qkv_channels; + floatX* l_qkvb = model->config.use_biases ? params.qkvb + l * qkv_channels: nullptr; floatX* l_attprojw = params.attprojw + l * C * C; - floatX* l_attprojb = params.attprojb + l * C; + floatX* l_attprojb = model->config.use_biases ? params.attprojb + l * C : nullptr; floatX* l_ln2w = params.ln2w + l * C; floatX* l_fcw = params.fcw + l * ffn_channels * C; - floatX* l_fcb = params.fcb + l * ffn_channels; + floatX* l_fcb = model->config.use_biases ? params.fcb + l * ffn_channels : nullptr; floatX* l_fcprojw = params.fcprojw + l * C * ffn_channels_post_gelu; - floatX* l_fcprojb = params.fcprojb + l * C; + floatX* l_fcprojb = model->config.use_biases ? params.fcprojb + l * C : nullptr; // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; @@ -850,15 +852,15 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets, floatX* dl_ln1w = grads.ln1w + l * C; floatX* dl_ln1b = grads.ln1b + l * C; floatX* dl_qkvw = grads.qkvw + l * qkv_channels * C; - floatX* dl_qkvb = grads.qkvb + l * qkv_channels; + floatX* dl_qkvb = model->config.use_biases ? grads.qkvb + l * qkv_channels : nullptr; floatX* dl_attprojw = grads.attprojw + l * C * C; - floatX* dl_attprojb = grads.attprojb + l * C; + floatX* dl_attprojb = model->config.use_biases ? grads.attprojb + l * C : nullptr; floatX* dl_ln2w = grads.ln2w + l * C; - floatX* dl_ln2b = grads.ln2b + l * C; + floatX* dl_ln2b = model->config.use_biases ? grads.ln2b + l * C : nullptr; floatX* dl_fcw = grads.fcw + l * ffn_channels * C; - floatX* dl_fcb = grads.fcb + l * ffn_channels; + floatX* dl_fcb = model->config.use_biases ? grads.fcb + l * ffn_channels : nullptr; floatX* dl_fcprojw = grads.fcprojw + l * C * ffn_channels_post_gelu; - floatX* dl_fcprojb = grads.fcprojb + l * C; + floatX* dl_fcprojb = model->config.use_biases ? grads.fcprojb + l * C : nullptr; // get the pointers of the activations for this layer floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; float* l_ln1_rstd = acts.ln1_rstd + l * B * T; @@ -883,7 +885,7 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets, swiglu_forward(l_fch_swiglu, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); } // backward the 2nd matmul of MLP - matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_swiglu, l_fcprojw, nullptr, B, T, ffn_channels_post_gelu, C, main_stream); + matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_swiglu, l_fcprojw, scratchF, B, T, ffn_channels_post_gelu, C, main_stream); // backward the swiglu here, use scratchX to hold the grad because SwiGLU can't be inplace swiglu_backward(dl_bt4c2, dl_bt4c, l_fch_pre_gelu, B, T, ffn_channels_post_gelu, main_stream); // backward the 1st matmul of MLP From 68666232594eebbd0a64d2b8d48b06c9c8209390 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 18:34:19 +0200 Subject: [PATCH 46/63] fix: match pytorch learning rate in test file --- test_llama3.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_llama3.cu b/test_llama3.cu index 6ab6fae83..948d263a9 100644 --- a/test_llama3.cu +++ b/test_llama3.cu @@ -294,7 +294,7 @@ int main(int argc, char *argv[]) { float grad_norm = llama3_calculate_grad_norm(&model, &multi_gpu_config); float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; - llama3_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); + llama3_update(&model, 1e-5f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); // print the timing information at the end printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); From 2c3fecced2b6f8798a8575fe5b734968a080e4d6 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 18:36:50 +0200 Subject: [PATCH 47/63] fix: gradient checking --- test_llama3.cu | 57 ++++++++++++++++++++++---------------------------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/test_llama3.cu b/test_llama3.cu index 948d263a9..a1697e37e 100644 --- a/test_llama3.cu +++ b/test_llama3.cu @@ -113,8 +113,6 @@ int main(int argc, char *argv[]) { size_t V = model.config.vocab_size; size_t Vp = model.config.padded_vocab_size; size_t maxT = model.config.max_seq_len; - size_t L = model.config.num_layers; - size_t C = model.config.channels; for (int i = 1; i < argc; i+=2) { if (i + 1 >= argc) { exit(EXIT_FAILURE); } // must have arg after flag @@ -185,7 +183,7 @@ int main(int argc, char *argv[]) { float loss_diff_threshold = 1e-5f; // FP16 and lower require very high tolerances unfortunately. TODO look into more #if defined(ENABLE_BF16) || defined(ENABLE_F16) - logit_accuracy_threshold = 25.0f; // 15.0f was too low even without cuDNN?! :( + logit_accuracy_threshold = 1.0f; loss_diff_threshold = 0.05f; #endif @@ -267,29 +265,24 @@ int main(int argc, char *argv[]) { // Also, different GPUs may use different matrix multiplication algorithms, so the // actual errors can be hardware specific. - float grad_thresholds[NUM_PARAMETER_TENSORS] = {5e-1f, 4e-3f, 1e-1f, 3.5e-2f, 2e-2f, 3e-2f, 5e-2f, 5e-2f, 5e-2f, 1.5e-2f, 5e-4f, 8e-3f, 1.5e-3f, 2.5e-3f, 1e-1f, 2e-2f}; + float grad_thresholds[NUM_PARAMETER_TENSORS] = { + 1e-1f, 4e-3f, 2e-2f, 8e-3f, 1e-1f,3.5e-2f, 2e-2f, + 0*3e-2f, 2e-2f, 2.5e-3f, 5e-2f,5e-2f, 1e-1f, 1.5e-2f, + 1e-1f, 2e-2f}; #if defined(ENABLE_FP32) for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { grad_thresholds[i] = 1e-6f; // we can be much more precise in FP32 } #endif - - allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", grad_thresholds[0]); - allok = allok & check_tensor(tensors1[1], tensors2[1], V * C, "wlmhead", grad_thresholds[1]); - allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", grad_thresholds[2]); - allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", grad_thresholds[3]); - allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", grad_thresholds[4]); - allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", grad_thresholds[5]); - allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", grad_thresholds[6]); - allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", grad_thresholds[7]); - allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", grad_thresholds[8]); - allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", grad_thresholds[9]); - allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", grad_thresholds[10]); - allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", grad_thresholds[11]); - allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", grad_thresholds[12]); - allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", grad_thresholds[13]); - allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", grad_thresholds[14]); - allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]); + const char* names[NUM_PARAMETER_TENSORS] = { + "wte", "wlmhead", "ln1w", "ln1b", "qkvw", "qkvb", "attrpojw", + "attprojb", "ln2w", "ln2b", "fcw", "fcb", "fcprojw", "fcprojb", + "lnfw", "lnfb" + }; + size_t* count = model.param_elements; + for(int i = 0; i < NUM_PARAMETER_TENSORS; ++i) { + allok = allok & check_tensor(tensors1[i], tensors2[i], count[i], names[i], grad_thresholds[i]); + } } float grad_norm = llama3_calculate_grad_norm(&model, &multi_gpu_config); @@ -304,18 +297,18 @@ int main(int argc, char *argv[]) { losses[step] = rounded_loss; } - // expected losses are as follows, from Python + // expected losses are as follows, from Python (without CPUOffload) float expected_losses[10] = { - 5.270009f, - 4.060681f, - 3.320085f, - 2.717550f, - 2.181066f, - 1.653923f, - 1.168050f, - 0.736873f, - 0.401021f, - 0.187493f + 4.849688f, + 3.070303f, + 1.711614f, + 1.056311f, + 0.593335f, + 0.428291f, + 0.372275f, + 0.360507f, + 0.355562f, + 0.334824f }; // compare From 24d91298cf32a09a90ee8d284e3e2f1d8d2ddba6 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 14 Apr 2025 10:35:04 +0200 Subject: [PATCH 48/63] fix: ensure `freqs_cis` are not broken when calling `model.to(dtype)` (was discarding imaginary part) --- train_llama3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_llama3.py b/train_llama3.py index 2fbee5d01..af48f235a 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -122,7 +122,7 @@ def precompute_freqs_cis( freqs = apply_scaling(freqs) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis + return torch.view_as_real(freqs_cis) # ----------------------------------------------------------------------------- # LLaMA building blocks @@ -331,7 +331,7 @@ def forward(self, idx, targets=None, return_logits=True, start_pos=0): # forward the LLaMA model itself x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - freqs_cis = self.freqs_cis[start_pos:start_pos+t] + freqs_cis = torch.view_as_complex(self.freqs_cis[start_pos:start_pos+t]) mask = torch.triu(torch.ones((t, t), device=next(self.parameters()).device, dtype=torch.bool), diagonal=1) for i, block in enumerate(self.transformer.h): From f8a43cea849ab132851196ff8ad15bd2aea44396 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 14 Apr 2025 10:52:21 +0200 Subject: [PATCH 49/63] fix: writing checkpoint --- train_llama3.cu | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/train_llama3.cu b/train_llama3.cu index c7b108f9d..0422447b4 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -495,16 +495,25 @@ void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) { // write the header first int model_header[256]; memset(model_header, 0, sizeof(model_header)); - model_header[0] = 20240326; // magic number + model_header[0] = 20240803; // magic number assert(PRECISION_MODE == PRECISION_FP32 || PRECISION_MODE == PRECISION_BF16); model_header[1] = PRECISION_MODE == PRECISION_FP32 ? 3 : 5; // version model_header[2] = model->config.max_seq_len; model_header[3] = model->config.vocab_size; model_header[4] = model->config.num_layers; model_header[5] = model->config.num_heads; - model_header[6] = model->config.channels; - model_header[7] = model->config.padded_vocab_size; + model_header[6] = model->config.num_kv_heads; + model_header[7] = model->config.channels; + model_header[8] = model->config.multiple_of; + model_header[9] = model->config.use_scaled_rope; + model_header[10] = 3; + model_header[11] = 1; fwriteCheck(model_header, sizeof(int), 256, model_file); + float float_header[256]; + float_header[0] = model->config.ffn_dim_multiplier; + float_header[1] = model->config.norm_eps; + float_header[2] = model->config.rope_theta; + fwriteCheck(float_header, sizeof(float), 256, model_file); // write the parameters device_to_file(model_file, model->params_memory, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); From 5b928298d9aea853730b7c25f6e9105dbf3ab420 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 13 Apr 2025 20:31:43 +0200 Subject: [PATCH 50/63] !! DROP THIS COMMIT !! hard-code a hf token to make the tests run --- .github/workflows/ci_gpu.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index 412c38a5b..adcd12821 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -120,10 +120,12 @@ jobs: build-and-test-llama3: runs-on: ubicloud-gpu-standard-1-latest - + env: + HF_TOKEN: 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 steps: - name: Checkout code uses: actions/checkout@v4 + - run: echo "::add-mask::$HF_TOKEN" - name: Install OpenMP run: sudo apt-get update && sudo apt-get install -y libomp-dev @@ -173,3 +175,4 @@ jobs: - name: Test Device<->File IO run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io + From 9c52a9557e1943d69eeff85842d5c2971b672481 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 14 Apr 2025 16:49:35 +0200 Subject: [PATCH 51/63] fix: CPUOffloadOptimizer + gradient clipping is broken; we use an inefficient workaround to make it correct --- .github/workflows/ci_gpu.yml | 2 +- train_llama3.py | 19 +++++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index adcd12821..1c400bab7 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -137,7 +137,7 @@ jobs: run: python dev/data/tinyshakespeare.py --model_desc llama-3 - name: Train model - run: python train_llama3.py --write_tensors 1 --dtype float32 + run: python train_llama3.py --write_tensors 1 --dtype float32 --offload 1 - name: Build FP32 precision run: PRECISION=FP32 make test_llama3cu diff --git a/train_llama3.py b/train_llama3.py index af48f235a..df430fd2e 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -1048,6 +1048,7 @@ def print0(*args, **kwargs): parser.add_argument("--compile", type=int, default=0, help="torch.compile the model") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16") parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)") + parser.add_argument("--offload", type=int, default=0, help="offload optimizer to CPU") # python -> C bridge parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk") args = parser.parse_args() @@ -1182,14 +1183,9 @@ def print0(*args, **kwargs): raw_model = model.module if ddp else model # always contains the "raw" unwrapped model # init the optimizer - offload = False - gpu_memory_mib = torch.cuda.get_device_properties(0).total_memory // 1024 // 1024 - if not ddp and gpu_memory_mib < 24_000: - print(f"GPU has only {gpu_memory_mib} MiB of memory, offloading optimizer to CPU") - offload = True optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, learning_rate=args.learning_rate, betas=(0.9, 0.95), - device_type=device, zero_stage=zero_stage, offload=offload) + device_type=device, zero_stage=zero_stage, offload=args.offload) # learning rate decay scheduler (cosine with warmup) def get_lr(it): @@ -1302,6 +1298,17 @@ def get_lr(it): dist.all_reduce(lossf, op=dist.ReduceOp.AVG) lossf = lossf.item() norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + if args.offload: + # CPUOffloadOptimizer is *not* compatible with gradient clipping and will *silently* + # give wrong results. So we + # a) explicitly wait for it to finish its gradients transfers + # b) overwrite the CPU gradients with the clipped GPU gradients. + # This is terribly inefficient, but correct and lets us run CI on + # small(ish) GPUs + torch.cuda.synchronize() + for gpu, cpu in optimizer.param_d2h_map.items(): + cpu.grad[...] = gpu.grad[...] + # determine and set the learning rate for this iteration lr = get_lr(step) for param_group in optimizer.param_groups: From 1c02d547168be12721e9f72c20aa25733ce42071 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Thu, 1 May 2025 14:59:57 +0200 Subject: [PATCH 52/63] cudnn does not support fp32 -> remove this pointless test --- .github/workflows/ci_gpu.yml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index 1c400bab7..0383e09af 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -104,20 +104,14 @@ jobs: git clone https://github.com/NVIDIA/cudnn-frontend.git - name: Build with cuDNN - run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu + run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu - name: Train model with cuDNN run: ./train_gpt2cu - - name: Train model fp32 with cuDNN - run: ./train_gpt2fp32cu - - name: Execute testing program with cuDNN run: ./test_gpt2cu - - name: Execute testing program fp32 with cuDNN - run: ./test_gpt2fp32cu - build-and-test-llama3: runs-on: ubicloud-gpu-standard-1-latest env: From 7b7d39c0070f0fdf98ba1661d4812f9ea48a8e60 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 2 May 2025 22:11:59 +0200 Subject: [PATCH 53/63] include grad norm in logging --- test_llama3.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_llama3.cu b/test_llama3.cu index a1697e37e..b139a4039 100644 --- a/test_llama3.cu +++ b/test_llama3.cu @@ -290,7 +290,7 @@ int main(int argc, char *argv[]) { llama3_update(&model, 1e-5f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); // print the timing information at the end - printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); + printf("step %d: loss %f norm %f (took %f ms)\n", step+1, model.mean_loss, grad_norm, time_elapsed_s * 1000); // the expected losses from PyTorch were copied over after the print formatting rounded // them to 6 decimal places, so we do the same here float rounded_loss = roundf(model.mean_loss * 1000000) / 1000000; From d4347a7115337a2004d5f729970513ab8370ba73 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 2 May 2025 22:16:05 +0200 Subject: [PATCH 54/63] ensure 32-bit master params in python training --- train_llama3.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/train_llama3.py b/train_llama3.py index df430fd2e..0a6c161f4 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -1133,9 +1133,11 @@ def print0(*args, **kwargs): assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path) - # convert the model to the desired precision - if args.dtype == "float32": - model = model.to(torch.float32) + # PT optimizer doesn't do stochastic rounding, so we + # really want the model to be in fp32 precision: + # --dtype should only enable AMP + # as the original checkpoints are in 16 bit, we need to convert + model = model.to(torch.float32) model = model.to(device) model.train() From 082d9fa78ed5246b02f046639898e8893e7e884f Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 2 May 2025 23:45:08 +0200 Subject: [PATCH 55/63] added missing stream argument for repkv_backward --- llmc/repkv.cuh | 6 +++--- train_llama3.cu | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh index a70881402..b5b8e2f12 100644 --- a/llmc/repkv.cuh +++ b/llmc/repkv.cuh @@ -50,7 +50,7 @@ __global__ void repkv_forward_kernel1(floatX* replicated_qkv, __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, int B, int N, int NH, int replicate_factor, int HD) { - // we have a single tensor dout of shapae of (B, N 3 * NH * HD) + // we have a single tensor dout of shape of (B, N 3 * NH * HD) // we want to reduce sum (for K and V) into (B, N, (NH + 2*(NH/replicate_factor)) * HD) int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= B * N * 3 * NH * HD) { return;} @@ -111,11 +111,11 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_ } void repkv_backward(floatX* dinp, const floatX* dout, - const int B, const int T, const int NH, const int NH_KV, const int d) { + const int B, const int T, const int NH, const int NH_KV, const int d, cudaStream_t stream) { const int block_size = 128; int total_threads = B * T * (3 * NH) * d; int num_blocks = CEIL_DIV(total_threads, block_size); int replicate_factor = NH / NH_KV; - repkv_backward_kernel1<<>>(dinp, dout, B, T, NH, replicate_factor, d); + repkv_backward_kernel1<<>>(dinp, dout, B, T, NH, replicate_factor, d); cudaCheck(cudaGetLastError()); } diff --git a/train_llama3.cu b/train_llama3.cu index 0422447b4..4177e3079 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -923,7 +923,7 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets, // backward rope (this can be done in-place) rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); // backward repkv (use scratchX as gradient buffer here) - repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); + repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd, main_stream); // backward QKV projection if(model->recompute >= 2) { rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); From a860922827c560a42d46bb4c3c8f2dda98b4ead4 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 4 May 2025 22:10:38 +0200 Subject: [PATCH 56/63] set stream for attention softmax --- llmc/attention.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmc/attention.cuh b/llmc/attention.cuh index f6294a213..639ab75d0 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -263,7 +263,7 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scrat matmul_cublaslt(dv, scratch, att, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); const float scale = 1.0f / sqrtf((float)HS); // backward into preatt. this is an in-place operation; datt turns into dpreatt here - softmax_autoregressive_backward_inplace_kernel<<>>(datt, att, B, T, C, scale); + softmax_autoregressive_backward_inplace_kernel<<>>(datt, att, B, T, C, scale); const floatX* dpreatt = datt; // backward into q matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); From f38eadce3596053a951e90516999a8e72b070ed6 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 4 May 2025 22:49:41 +0200 Subject: [PATCH 57/63] allow reducing number of transformer blocks to make smaller models that can be tested on commodity GPUs --- requirements.txt | 3 +-- train_llama3.py | 25 ++++++++----------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5e00477e6..46d3a97d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,4 @@ torch tiktoken transformers datasets -requests -torchao \ No newline at end of file +requests \ No newline at end of file diff --git a/train_llama3.py b/train_llama3.py index 0a6c161f4..f870bf266 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -472,7 +472,7 @@ def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path): model.tokenizer = tokenizer return model - def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage, offload): + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage): # start with all of the candidate parameters param_dict = {pn: p for pn, p in self.named_parameters()} # filter out those that do not require grad @@ -494,14 +494,10 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, use_fused = fused_available and device_type == 'cuda' print0(f"using fused AdamW: {use_fused}") if zero_stage == 1: - assert not offload print0("using ZeroRedundancyOptimizer") optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW, lr=learning_rate, betas=betas, fused=use_fused) optimizer.add_param_group(optim_groups[1]) - elif offload: - from torchao.prototype.low_bit_optim import CPUOffloadOptimizer - optimizer = CPUOffloadOptimizer(optim_groups, torch.optim.AdamW, lr=learning_rate, betas=betas, fused=use_fused) else: print0("using regular AdamW") optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) @@ -1022,6 +1018,7 @@ def print0(*args, **kwargs): parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model") + parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers") # token layout for each step of the optimization parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions") parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") @@ -1048,7 +1045,6 @@ def print0(*args, **kwargs): parser.add_argument("--compile", type=int, default=0, help="torch.compile the model") parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16") parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)") - parser.add_argument("--offload", type=int, default=0, help="offload optimizer to CPU") # python -> C bridge parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk") args = parser.parse_args() @@ -1133,6 +1129,11 @@ def print0(*args, **kwargs): assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist" model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path) + if args.depth > 0: + assert args.depth < len(model.transformer.h), f"invalid depth {args.depth}, model has {len(model.transformer.h)} blocks" + model.transformer.h = model.transformer.h[0:args.depth] + model.config.n_layer = args.depth + # PT optimizer doesn't do stochastic rounding, so we # really want the model to be in fp32 precision: # --dtype should only enable AMP @@ -1187,7 +1188,7 @@ def print0(*args, **kwargs): # init the optimizer optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, learning_rate=args.learning_rate, betas=(0.9, 0.95), - device_type=device, zero_stage=zero_stage, offload=args.offload) + device_type=device, zero_stage=zero_stage) # learning rate decay scheduler (cosine with warmup) def get_lr(it): @@ -1300,16 +1301,6 @@ def get_lr(it): dist.all_reduce(lossf, op=dist.ReduceOp.AVG) lossf = lossf.item() norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) - if args.offload: - # CPUOffloadOptimizer is *not* compatible with gradient clipping and will *silently* - # give wrong results. So we - # a) explicitly wait for it to finish its gradients transfers - # b) overwrite the CPU gradients with the clipped GPU gradients. - # This is terribly inefficient, but correct and lets us run CI on - # small(ish) GPUs - torch.cuda.synchronize() - for gpu, cpu in optimizer.param_d2h_map.items(): - cpu.grad[...] = gpu.grad[...] # determine and set the learning rate for this iteration lr = get_lr(step) From 35e1ad6fcce6921bdd57a8e768b73b491f3227ef Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 4 May 2025 23:28:44 +0200 Subject: [PATCH 58/63] enable storing the expected loss values in the state file, so we can run testing with different model configurations fup --- test_llama3.cu | 21 ++++++--------------- train_llama3.py | 26 +++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/test_llama3.cu b/test_llama3.cu index b139a4039..03fe1abec 100644 --- a/test_llama3.cu +++ b/test_llama3.cu @@ -128,17 +128,20 @@ int main(int argc, char *argv[]) { int state_header[256]; freadCheck(state_header, sizeof(int), 256, state_file); if (state_header[0] != 20240803) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); } - if (state_header[1] != 2) { + if (state_header[1] != 3) { fprintf(stderr, "Bad version in state file: %d\n", state_header[1]); fprintf(stderr, "---> HINT: try to re-run `python train_llama3.py`\n"); exit(EXIT_FAILURE); } int B = state_header[2]; // batch size, e.g. 4 int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT) + int steps = state_header[4]; + float* expected_losses = (float*) malloc(steps * sizeof(float)); assert(0 <= T && T <= maxT); printf("[State]\n"); printf("batch_size: %d\n", B); printf("seq_len: %d\n", T); + printf("steps: %d\n", steps); set_zero_configs(&multi_gpu_config, 0, model.num_parameters); @@ -157,6 +160,7 @@ int main(int argc, char *argv[]) { FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32 float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements); freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file); + freadCheck(expected_losses, sizeof(float), steps, state_file); fcloseCheck(state_file); // this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads @@ -297,20 +301,6 @@ int main(int argc, char *argv[]) { losses[step] = rounded_loss; } - // expected losses are as follows, from Python (without CPUOffload) - float expected_losses[10] = { - 4.849688f, - 3.070303f, - 1.711614f, - 1.056311f, - 0.593335f, - 0.428291f, - 0.372275f, - 0.360507f, - 0.355562f, - 0.334824f - }; - // compare for (int i = 0; i < 10; i++) { if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) { @@ -377,6 +367,7 @@ int main(int argc, char *argv[]) { common_free(model); free(x); free(y); + free(expected_losses); free(logits_cpu_raw); free(logits_cpu); free(expected_logits); diff --git a/train_llama3.py b/train_llama3.py index f870bf266..30c6834db 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -976,9 +976,10 @@ def write_state(model, x, y, logits, loss, filename): # this can be used for checking the computation correctness in C header = torch.zeros(256, dtype=torch.int32) header[0] = 20240803 # magic - header[1] = 2 # version + header[1] = 3 # version header[2] = x.size(0) # batch size of the batch, B header[3] = x.size(1) # temporal extent of the batch, T + header[4] = 0 grads = {name: param.grad.cpu() for name, param in model.named_parameters()} with open(filename, "wb") as file: # header @@ -995,6 +996,22 @@ def write_state(model, x, y, logits, loss, filename): write_tensors(grads, model.config.n_layer, file, "float32") print(f"wrote {filename}") + +def write_training_history(losses, norms, filename): + # amends the state file with the sequence of losses and grad norms + assert len(norms) == len(losses) + with open(filename, "r+b") as f: + header = np.frombuffer(f.read(256*4), dtype=np.int32).copy() + header[4] = len(losses) + f.seek(0, os.SEEK_SET) + f.write(header.tobytes()) + f.seek(0, os.SEEK_END) + # write the losses and norms at the end of the file + f.write(np.asarray(losses).astype(np.float32).tobytes()) + f.write(np.asarray(norms).astype(np.float32).tobytes()) + + print(f"updated {filename}") + # ----------------------------------------------------------------------------- # int main @@ -1208,6 +1225,8 @@ def get_lr(it): if device == "cuda": torch.cuda.reset_peak_memory_stats() timings = [] + losses = [] + norms = [] norm = -1.0 # dummy value to print in inference-only mode for step in range(args.num_iterations + 1): t0 = time.time() @@ -1320,6 +1339,8 @@ def get_lr(it): t1 = time.time() # the 0th iteration is often an outlier (much slower) => skip logging it tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0) + losses.append(lossf) + norms.append(norm.item()) print0(f"step {step+1:4d}/{args.num_iterations} | train loss {lossf:.6f} | norm {norm:.4f} | lr {lr:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)") # log to logile if master_process and logfile is not None: @@ -1330,6 +1351,9 @@ def get_lr(it): if step > 0 and step > args.num_iterations - 20: timings.append(t1-t0) + if master_process and args.write_tensors and (not args.inference_only): + write_training_history(losses, norms, f"llama3_{model_size_str}_debug_state.bin") + # print the average of the last 20 timings, to get something smooth-ish timings = timings[-20:] print0(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms") From 76a7cce3dbf726ddfbd74cc95e7a5cbf3fb7e8c9 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Sun, 4 May 2025 23:29:10 +0200 Subject: [PATCH 59/63] replace offload with smaller model --- .github/workflows/ci_gpu.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index 0383e09af..ac2b7e9e1 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -131,7 +131,9 @@ jobs: run: python dev/data/tinyshakespeare.py --model_desc llama-3 - name: Train model - run: python train_llama3.py --write_tensors 1 --dtype float32 --offload 1 + # use the first 10 layers, so that everything fits into the 20GB of + # the A4000 Ada that we have in CI + run: python train_llama3.py --write_tensors 1 --dtype float32 --depth 10 - name: Build FP32 precision run: PRECISION=FP32 make test_llama3cu From 9c606162760cca1e49b333acb3b7ec18e2f7b003 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Thu, 26 Jun 2025 17:52:20 +0200 Subject: [PATCH 60/63] fix out-of-bounds access in rmsnorm kernel --- llmc/rmsnorm.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmc/rmsnorm.cuh b/llmc/rmsnorm.cuh index 8f20e9864..5be05bb66 100644 --- a/llmc/rmsnorm.cuh +++ b/llmc/rmsnorm.cuh @@ -90,7 +90,7 @@ __global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* __syncthreads(); int idx = blockIdx.x * blockDim.y + threadIdx.y; - if(idx > N) return; + if(idx >= N) return; // adjust pointers to current token residual += C * idx; From ffcfe99955e1d5373d5b644b49e5d80a6225abd1 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Thu, 26 Jun 2025 17:52:20 +0200 Subject: [PATCH 61/63] fix out-of-bounds access in rmsnorm kernel --- llmc/rmsnorm.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmc/rmsnorm.cuh b/llmc/rmsnorm.cuh index 8f20e9864..5be05bb66 100644 --- a/llmc/rmsnorm.cuh +++ b/llmc/rmsnorm.cuh @@ -90,7 +90,7 @@ __global__ void fused_residual_rmsnorm_forward_kernel5(floatX* residual, floatX* __syncthreads(); int idx = blockIdx.x * blockDim.y + threadIdx.y; - if(idx > N) return; + if(idx >= N) return; // adjust pointers to current token residual += C * idx; From 9688eef51991af7be738b30d0c6ace7714995bf7 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 14 Apr 2025 13:07:37 +0200 Subject: [PATCH 62/63] enable tied embeddings --- test_llama3.cu | 3 +++ train_llama3.cu | 37 ++++++++++++++++++++++++++++++------- train_llama3.py | 16 ++++++++++------ 3 files changed, 43 insertions(+), 13 deletions(-) diff --git a/test_llama3.cu b/test_llama3.cu index 03fe1abec..db82bc4bd 100644 --- a/test_llama3.cu +++ b/test_llama3.cu @@ -85,6 +85,9 @@ float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size *(ptrs[i]) = params_memory_iterator; params_memory_iterator += param_sizes[i]; } + if(param_sizes[1] == 0) { + params->wlmhead = nullptr; + } return params_memory; } diff --git a/train_llama3.cu b/train_llama3.cu index 4177e3079..7ea5a6e5d 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -106,6 +106,7 @@ typedef struct { float norm_eps; // epsilon used in layernorm, e.g. 1e-5 float rope_theta; // theta used in ROPE attention, e.g. 500000.0 (<-- new in Llama 3) bool use_biases; // we always allocate memory for biases; to match llama3 they are not used + bool tied_weights; // untied for large models (3.1 8B/70B/405B), tied for small (3.2 1B/3B) } LLama3Config; // the parameters of the model @@ -153,7 +154,12 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, LLama3Co size_t ffn_channels = hidden_dim * 2; // c_fc + c_fc2 concatenated // now populate the parameter sizes param_sizes[0] = Vp * C; // wte - param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights) + if(config.tied_weights) { + param_sizes[1] = 0; // no lm_head with tied weights + } else { + param_sizes[1] = Vp * C; // (3) lm_head (final classifier layer weights) + } + param_sizes[2] = L * C; // ln1w param_sizes[3] = L * C; // ln1b; (1) all biases are zero it's ok param_sizes[4] = L * (qkv_channels) * C; // qkvw @@ -195,6 +201,10 @@ void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elemen *(ptrs[i]) = (floatX*)params_memory_iterator; params_memory_iterator += param_elements[i] * param_sizeof[i]; } + // tied weights? + if(param_elements[1] == 0) { + params->wlmhead = nullptr; + } return params_memory; } @@ -506,8 +516,9 @@ void llama3_write_to_checkpoint(LLama3 *model, const char* checkpoint_path) { model_header[7] = model->config.channels; model_header[8] = model->config.multiple_of; model_header[9] = model->config.use_scaled_rope; - model_header[10] = 3; - model_header[11] = 1; + model_header[10] = model->config.tied_weights; + model_header[11] = 3; + model_header[12] = model->config.tied_weights ? 2 : 1; fwriteCheck(model_header, sizeof(int), 256, model_file); float float_header[256]; float_header[0] = model->config.ffn_dim_multiplier; @@ -580,8 +591,9 @@ void llama3_build_from_checkpoint(LLama3 *model, const char* checkpoint_path, bo model->config.multiple_of = header_int[8]; model->config.use_scaled_rope = header_int[9]; model->config.use_biases = false; - int major_version = header_int[10]; // currently unused, e.g. 3 - int minor_version = header_int[11]; // currently unused, e.g. 1 (so Llama 3.1) + model->config.tied_weights = header_int[10]; + int major_version = header_int[11]; // currently unused, e.g. 3 + int minor_version = header_int[12]; // 1 or 2 // now the float section model->config.ffn_dim_multiplier = header_float[0]; model->config.norm_eps = header_float[1]; @@ -740,7 +752,9 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) { } } - matmul_forward_cublaslt(acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream); + floatX* lm_head = model->config.tied_weights ? params.wte : params.wlmhead; + matmul_forward_cublaslt(acts.output, acts.lnf, lm_head, NULL, B, T, C, Vp, main_stream); + cudaCheck(cudaDeviceSynchronize()); } @@ -836,7 +850,10 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets, // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(model->acts.scratch_bt4c, grads.wlmhead, NULL, acts.output, acts.lnf, params.wlmhead, NULL, B, T, C, Vp, main_stream); + floatX* w_lm_head = model->config.tied_weights ? params.wte : params.wlmhead; + floatX* g_lm_head = model->config.tied_weights ? grads.wte : grads.wlmhead; + + matmul_backward(model->acts.scratch_bt4c, g_lm_head, NULL, acts.output, acts.lnf, w_lm_head, NULL, B, T, C, Vp, main_stream); // backward the final layernorm floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 rmsnorm_backward(dresidual, grads.lnfw, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_rstd, B, T, C, main_stream); @@ -1076,6 +1093,8 @@ void llama3_update(LLama3 *model, float learning_rate, float beta1, float beta2, } ShardInfo tensor = llama3_get_tensor_at_layer(model, 0, i); + if(tensor.size == 0) + continue; ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); ptrdiff_t local_offset_full = tensor.offset + shard.offset; ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes; @@ -1144,6 +1163,10 @@ float llama3_estimate_mfu(LLama3 *model, int num_tokens, float dt) { second is the attention matmul, which is also usually a small contribution. */ size_t N = model->num_parameters; + if(!model->config.tied_weights) { + N -= model->param_elements[0]; // remove embedding parameters, which can be significant at 128k vocab + } + int L = model->config.num_layers; int C = model->config.channels; int T = model->seq_len; diff --git a/train_llama3.py b/train_llama3.py index 30c6834db..04d7ca1d8 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -312,6 +312,8 @@ def __init__(self, config): ln_f = RMSNorm(config.n_embd, config.norm_eps), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + if config.tied_embeddings: + self.transformer.wte.weight = self.lm_head.weight # init all weights, use a torch rng object to be very careful self.init_rng = torch.Generator() @@ -876,7 +878,7 @@ def write_bf16(tensor, file): b = t.numpy().tobytes() file.write(b) -def write_tensors(model_tensors, L, file, dtype): +def write_tensors(model_tensors, L, tied, file, dtype): # writes LLaMA 3 model's weights to a binary file # things get a bit more complicated though: # 1) We want to maintain the ability to finetune just the biases in the C code @@ -894,7 +896,8 @@ def write_tensors(model_tensors, L, file, dtype): assert dtype in {"float32", "bfloat16"} write_fun = write_fp32 if dtype == "float32" else write_bf16 write_fun(model_tensors["transformer.wte.weight"], file) # (V, C) - write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here! + if not tied: + write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here! for i in range(L): # (L, C) write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) for i in range(L): # (L, C) @@ -954,8 +957,9 @@ def write_model(model, filename, dtype): header_int[7] = model.config.n_embd header_int[8] = model.config.multiple_of header_int[9] = int(model.config.use_scaled_rope) - header_int[10] = int(model.config.version.split('.')[0]) # major version - header_int[11] = int(model.config.version.split('.')[1]) # minor version + header_int[10] = int(model.config.tied_embeddings) + header_int[11] = int(model.config.version.split('.')[0]) # major version + header_int[12] = int(model.config.version.split('.')[1]) # minor version # float section of the header header_float = torch.zeros(256, dtype=torch.float32) header_float[0] = model.config.ffn_dim_multiplier @@ -967,7 +971,7 @@ def write_model(model, filename, dtype): with open(filename, "wb") as file: file.write(header_int.numpy().tobytes()) # int header file.write(header_float.numpy().tobytes()) # float header - write_tensors(params, model.config.n_layer, file, dtype) # params + write_tensors(params, model.config.n_layer, model.config.tied_embeddings, file, dtype) # params print(f"wrote {filename}") def write_state(model, x, y, logits, loss, filename): @@ -993,7 +997,7 @@ def write_state(model, x, y, logits, loss, filename): # loss (single float, result of the cross entropy loss) write_fp32(loss.cpu(), file) # gradients - write_tensors(grads, model.config.n_layer, file, "float32") + write_tensors(grads, model.config.n_layer, model.config.tied_embeddings, file, "float32") print(f"wrote {filename}") From 9caeceb7ae445e1998ae1edf18018fcbea94f2bd Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Mon, 14 Apr 2025 13:22:39 +0200 Subject: [PATCH 63/63] command-line overwrite to forcibly untie embeddings for llama3.2 models --- .github/workflows/ci_gpu.yml | 43 ++++++++++++++++++++++++++++++++---- train_llama3.py | 11 +++++++-- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index ac2b7e9e1..d52285c82 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -113,6 +113,7 @@ jobs: run: ./test_gpt2cu build-and-test-llama3: + name: Build and test LLama3.2 1B runs-on: ubicloud-gpu-standard-1-latest env: HF_TOKEN: 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 @@ -150,18 +151,52 @@ jobs: - name: Build BF16 precision run: PRECISION=BF16 make train_llama3cu test_llama3cu - - name: Run default + - name: Run default (BF16) run: ./test_llama3cu - - name: Run no recompute GeLU + - name: Run no recompute GeLU (BF16) run: ./test_llama3cu -r 0 - - name: Run no master weights + - name: Run no master weights (BF16) run: ./test_llama3cu -w 0 - - name: Run recompute LN + - name: Run recompute LN (BF16) run: ./test_llama3cu -r 2 + build-and-test-llama3-untied: + name: Build and test LLama3.2 1B with untie weights + runs-on: ubicloud-gpu-standard-1-latest + env: + HF_TOKEN: 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 + steps: + - name: Checkout code + uses: actions/checkout@v4 + - run: echo "::add-mask::$HF_TOKEN" + + - name: Install OpenMP + run: sudo apt-get update && sudo apt-get install -y libomp-dev + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Run preprocessing + run: python dev/data/tinyshakespeare.py --model_desc llama-3 + + - name: Train model + run: python train_llama3.py --write_tensors 1 --dtype float32 --untie 1 --depth 10 + + - name: Build FP32 precision + run: PRECISION=FP32 make test_llama3cu + + - name: Run default + run: ./test_llama3cu + + - name: Build BF16 precision + run: PRECISION=BF16 make train_llama3cu test_llama3cu + + - name: Run default + run: ./test_llama3cu + unit-tests-gpu: runs-on: ubicloud-gpu-standard-1-latest diff --git a/train_llama3.py b/train_llama3.py index 04d7ca1d8..2fc64a644 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -435,10 +435,16 @@ def unpermute(w, n_heads, dim1, dim2): return checkpoint @classmethod - def from_pretrained_llama3_hf(cls, model_id): + def from_pretrained_llama3_hf(cls, model_id, untie): """Loads pretrained LLaMA model weights from HuggingFace""" from transformers import AutoModelForCausalLM, AutoTokenizer model_args = MODEL_DICT[model_id] + if untie: + if not model_args.tied_embeddings: + print("Model embeddings are not tied, --untie has no effect.") + else: + print("Untying token embeddings and LM head.") + model_args.tied_embeddings = False model = AutoModelForCausalLM.from_pretrained(model_id) checkpoint = LLaMA.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args) @@ -1040,6 +1046,7 @@ def print0(*args, **kwargs): parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model") parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers") + parser.add_argument("--untie", type=int, default=False, help="Untie token embeddings and LM-head, even if they are tied in the checkpoint.") # token layout for each step of the optimization parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions") parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") @@ -1144,7 +1151,7 @@ def print0(*args, **kwargs): # init the model if args.use_hf: - model = LLaMA.from_pretrained_llama3_hf(args.model) + model = LLaMA.from_pretrained_llama3_hf(args.model, args.untie) else: # use Meta's checkpoint assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist" assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"