diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4840d30e4..8821fbcf1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -234,3 +234,11 @@ jobs: - name: Build BF16 run: PRECISION=BF16 make test_llama3cu train_llama3cu + + - name: Get cudnn + run: | + apt-get update && apt-get install -y git + git clone https://github.com/NVIDIA/cudnn-frontend.git + + - name: Build cuDNN + run: PRECISION=BF16 USE_CUDNN=1 make test_llama3cu train_llama3cu diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index d52285c82..f545985a0 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -115,29 +115,29 @@ jobs: build-and-test-llama3: name: Build and test LLama3.2 1B runs-on: ubicloud-gpu-standard-1-latest - env: - HF_TOKEN: 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 + container: + image: nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 steps: - name: Checkout code uses: actions/checkout@v4 - - run: echo "::add-mask::$HF_TOKEN" + - run: echo "::add-mask::$(echo us_xrYQGKBiJeqDMlTxkGhSgjelZKYbJHTgDY | tr 'A-Za-z' 'N-ZA-Mn-za-m')" - name: Install OpenMP - run: sudo apt-get update && sudo apt-get install -y libomp-dev + run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev python3-pip - name: Install dependencies run: pip install -r requirements.txt - name: Run preprocessing - run: python dev/data/tinyshakespeare.py --model_desc llama-3 + run: HF_TOKEN=$(echo us_xrYQGKBiJeqDMlTxkGhSgjelZKYbJHTgDY | tr 'A-Za-z' 'N-ZA-Mn-za-m') python3 dev/data/tinyshakespeare.py --model_desc llama-3 - name: Train model # use the first 10 layers, so that everything fits into the 20GB of # the A4000 Ada that we have in CI - run: python train_llama3.py --write_tensors 1 --dtype float32 --depth 10 + run: HF_TOKEN=$(echo us_xrYQGKBiJeqDMlTxkGhSgjelZKYbJHTgDY | tr 'A-Za-z' 'N-ZA-Mn-za-m') python3 train_llama3.py --write_tensors 1 --dtype float32 --depth 10 - name: Build FP32 precision - run: PRECISION=FP32 make test_llama3cu + run: PRECISION=FP32 NO_MULTI_GPU=1 make test_llama3cu - name: Run default run: ./test_llama3cu @@ -149,7 +149,7 @@ jobs: run: ./test_llama3cu -r 2 - name: Build BF16 precision - run: PRECISION=BF16 make train_llama3cu test_llama3cu + run: PRECISION=BF16 NO_MULTI_GPU=1 make train_llama3cu test_llama3cu - name: Run default (BF16) run: ./test_llama3cu @@ -166,15 +166,12 @@ jobs: build-and-test-llama3-untied: name: Build and test LLama3.2 1B with untie weights runs-on: ubicloud-gpu-standard-1-latest - env: - HF_TOKEN: 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 steps: - name: Checkout code uses: actions/checkout@v4 - - run: echo "::add-mask::$HF_TOKEN" - name: Install OpenMP - run: sudo apt-get update && sudo apt-get install -y libomp-dev + run: sudo apt-get update && sudo apt-get install -y libomp-dev git - name: Install dependencies run: pip install -r requirements.txt @@ -197,6 +194,19 @@ jobs: - name: Run default run: ./test_llama3cu + - name: Install cuDNN-frontend + run: + git clone https://github.com/NVIDIA/cudnn-frontend.git + + - name: Build with cuDNN + run: USE_CUDNN=1 PRECISION=BF16 NO_MULTI_GPU=1 make train_llama3cu test_llama3cu + + - name: Train model with cuDNN + run: ./train_llama3cu + + - name: Execute testing program with cuDNN + run: ./test_llama3cu + unit-tests-gpu: runs-on: ubicloud-gpu-standard-1-latest diff --git a/dev/data/fineweb.py b/dev/data/fineweb.py index 12e1eb7c7..9948f3867 100644 --- a/dev/data/fineweb.py +++ b/dev/data/fineweb.py @@ -66,7 +66,7 @@ def tokenize_llama(doc): # tokenizes a single document and returns a numpy array of uint32 tokens - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B") + tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True) eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000) tokens = [eot] # the special <|endoftext|> token delimits all documents diff --git a/dev/data/tinyshakespeare.py b/dev/data/tinyshakespeare.py index 758a9a478..4d9b3a06e 100644 --- a/dev/data/tinyshakespeare.py +++ b/dev/data/tinyshakespeare.py @@ -50,7 +50,7 @@ def tokenize(model_desc): encode = lambda s: enc.encode_ordinary(s) eot = enc._special_tokens['<|endoftext|>'] # end of text token elif model_desc == "llama-3": - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B") + tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True) eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000) else: diff --git a/dev/data/tinystories.py b/dev/data/tinystories.py index 915c3218f..4a2ca82ab 100644 --- a/dev/data/tinystories.py +++ b/dev/data/tinystories.py @@ -76,7 +76,7 @@ def process_shard(shard_index, shard_filename, model_desc): encode = lambda s: enc.encode_ordinary(s) eot = enc._special_tokens['<|endoftext|>'] # end of text token elif model_desc == "llama-3": - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B") + tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B") encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True) eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000) else: diff --git a/llmc/cudnn_att.cpp b/llmc/cudnn_att.cpp index 0330abe20..c34088c04 100644 --- a/llmc/cudnn_att.cpp +++ b/llmc/cudnn_att.cpp @@ -53,15 +53,15 @@ enum UIDs { }; // Need a cache because graph->build_operation_graph() is slow but everything else seems fast -using cache_type_fwd = std::map, std::shared_ptr>; -using cache_type_bwd = std::map, std::shared_ptr>; +using cache_type_fwd = std::map, std::shared_ptr>; +using cache_type_bwd = std::map, std::shared_ptr>; // Loosely based on cuDNN frontend samples functions and massively simplified -auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_only) { +auto lookup_cache_or_build_graph_fwd(int B, int Hq, int Hkv, int T, int HS, int is_inference_only) { static cache_type_fwd user_maintained_cache_fwd; - auto key = std::make_tuple(B, H, T, HS, is_inference_only); + auto key = std::make_tuple(B, Hq, Hkv, T, HS, is_inference_only); auto it = user_maintained_cache_fwd.find(key); if (it != user_maintained_cache_fwd.end()) { @@ -74,18 +74,21 @@ auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_ .set_compute_data_type(fe::DataType_t::FLOAT); // QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute + // for (B, N, (NH + 2*(NH/replicate_factor)) * HS) + // (B, T, Hq + 2Hkv, HS) + int H = Hq + 2 * Hkv; auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q") - .set_dim({B, H, T, HS}) + .set_dim({B, Hq, T, HS}) .set_uid(Q_UID) - .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); + .set_stride({H * HS * T, HS, H * HS, 1})); auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K") - .set_dim({B, H, T, HS}) + .set_dim({B, Hkv, T, HS}) .set_uid(K_UID) - .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); + .set_stride({H * HS * T, HS, H * HS, 1})); auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V") - .set_dim({B, H, T, HS}) + .set_dim({B, Hkv, T, HS}) .set_uid(V_UID) - .set_stride({3 * H * HS * T, HS, 3 * H * HS, 1})); + .set_stride({H * HS * T, HS, H * HS, 1})); auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -101,14 +104,14 @@ auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_ // Create the graph operation and get the output tensors back auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options); - // Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32 - O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(O_UID); + // Output is (B, T, Hq, HS) BF16/FP16 and stats for backward pass is (B, Hq, T) FP32 + O->set_output(true).set_dim({B, Hq, T, HS}).set_stride({Hq * HS * T, HS, Hq * HS, 1}).set_uid(O_UID); assert(stats == nullptr || is_inference_only == false); if (is_inference_only == false) { stats->set_output(true).set_data_type(fe::DataType_t::FLOAT) - .set_dim({B, H, T, 1}) - .set_stride({H * T, T, 1, 1}) + .set_dim({B, Hq, T, 1}) + .set_stride({Hq * T, T, 1, 1}) .set_uid(Stats_UID); } @@ -134,10 +137,10 @@ auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_ return graph; } -auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { +auto lookup_cache_or_build_graph_bwd(int B, int Hq, int Hkv, int T, int HS) { static cache_type_bwd user_maintained_cache_bwd; - auto key = std::make_tuple(B, NH, T, HS); + auto key = std::make_tuple(B, Hq, Hkv, T, HS); auto it = user_maintained_cache_bwd.find(key); if (it != user_maintained_cache_bwd.end()) { @@ -151,31 +154,32 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { // (B, N, 3, NH, HS) // must come from inp (which means we also need to convert THAT to FP16) + int H = Hq + 2*Hkv; auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q") - .set_dim({B, NH, T, HS}) + .set_dim({B, Hq, T, HS}) .set_uid(Q_UID) - .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1})); + .set_stride({H * HS * T, HS, H * HS, 1})); auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K") - .set_dim({B, NH, T, HS}) + .set_dim({B, Hkv, T, HS}) .set_uid(K_UID) - .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1})); + .set_stride({H * HS * T, HS, H * HS, 1})); auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V") - .set_dim({B, NH, T, HS}) + .set_dim({B, Hkv, T, HS}) .set_uid(V_UID) - .set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1})); + .set_stride({H * HS * T, HS, H * HS, 1})); auto O = graph->tensor(fe::graph::Tensor_attributes().set_name("O") - .set_dim({B, NH, T, HS}) + .set_dim({B, Hq, T, HS}) .set_uid(O_UID) - .set_stride({NH * HS * T, HS, NH * HS, 1})); + .set_stride({Hq * HS * T, HS, Hq * HS, 1})); auto dO = graph->tensor(fe::graph::Tensor_attributes().set_name("dO") - .set_dim({B, NH, T, HS}) + .set_dim({B, Hq, T, HS}) .set_uid(dO_UID) - .set_stride({NH * HS * T, HS, NH * HS, 1})); + .set_stride({Hq * HS * T, HS, Hq * HS, 1})); auto stats = graph->tensor(fe::graph::Tensor_attributes().set_name("stats") - .set_dim({B, NH, T, 1}) + .set_dim({B, Hq, T, 1}) .set_uid(Stats_UID) - .set_stride({NH * T, T, 1, 1}) + .set_stride({Hq * T, T, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -193,9 +197,9 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { // Create the graph operation and get the output tensors back auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options); - dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dQ_UID); - dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dK_UID); - dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dV_UID); + dQ->set_output(true).set_dim({B, Hq, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(dQ_UID); + dK->set_output(true).set_dim({B, Hkv, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(dK_UID); + dV->set_output(true).set_dim({B, Hkv, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(dV_UID); checkCudnnFE(graph->validate()); @@ -219,23 +223,22 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { return graph; } -void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) - float* stats, // output for backward pass: (B, NH, T) - floatX* inp, // input: (B, T, 3, NH, HS) QKV - int B, int T, int NH, int C, cudaStream_t stream) { +void attention_forward_cudnn(floatX* out, // output: (B, T, Hq, HS) + float* stats, // output for backward pass: (B, Hq, T) + floatX* inp, // input: (B, T, Hq + Hk + Hv, HS) QKV + int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream) { NVTX_RANGE_FN(); - int HS = C / NH; // number of features per head bool is_inference_only = (stats == nullptr); cuDNNCheck(cudnnSetStream(cudnn_handle, stream)); // Get graph and tensors from cache (or generate it on first use) - auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only); + auto graph = lookup_cache_or_build_graph_fwd(B, Hq, Hkv, T, HS, is_inference_only); // Prepare all the tensor pointers for executing the graph void* devPtrQ = inp; - void* devPtrK = (inp + C); - void* devPtrV = (inp + 2 * C); + void* devPtrK = (inp + Hq * HS); + void* devPtrV = (inp + (Hq + Hkv) * HS); float attn_scale_cpu = 1.0 / sqrtf(HS); void* devPtrO = out; @@ -255,25 +258,24 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) void attention_backward_cudnn(floatX* dqkvr, // output floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs - int B, int T, int NH, int C, cudaStream_t stream) { + int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream) { NVTX_RANGE_FN(); - int HS = C / NH; // number of features per head // Get graph and tensors from cache (or generate it on first use) - auto graph = lookup_cache_or_build_graph_bwd(B, NH, T, HS); + auto graph = lookup_cache_or_build_graph_bwd(B, Hq, Hkv, T, HS); // Prepare all the tensor pointers for executing the graph void* devPtrQ = qkvr; - void* devPtrK = (qkvr + NH * HS); - void* devPtrV = (qkvr + 2 * NH * HS); + void* devPtrK = (qkvr + Hq * HS); + void* devPtrV = (qkvr + (Hq + Hkv) * HS); void* devPtrO = o; void* devPtrdO = dout; void* devPtrStats = stats; float attn_scale_cpu = 1.0 / sqrtf(HS); void* devPtrdQ = dqkvr; - void* devPtrdK = (dqkvr + NH * HS); - void* devPtrdV = (dqkvr + 2 * NH * HS); + void* devPtrdK = (dqkvr + Hq * HS); + void* devPtrdV = (dqkvr + (Hq + Hkv) * HS); // Build variant pack that links each tensor to its data pointer std::unordered_map variant_pack = { diff --git a/llmc/cudnn_att.h b/llmc/cudnn_att.h index 318413007..09db61a78 100644 --- a/llmc/cudnn_att.h +++ b/llmc/cudnn_att.h @@ -9,13 +9,13 @@ cuDNN (flash) attention // forward declarations of functions defined in cudnn_att.cpp void create_cudnn(); void destroy_cudnn(); -void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) - float* stats, // output for backward pass: (B, NH, T) - floatX* inp, // input: (B, T, 3, NH, HS) QKV - int B, int T, int NH, int C, cudaStream_t stream); +void attention_forward_cudnn(floatX* out, // output: (B, T, Nq, HS) + float* stats, // output for backward pass: (B, Hq, T) + floatX* inp, // input: (B, T, Hq + 2Hkv, HS) QKV + int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream); void attention_backward_cudnn(floatX* dqkvr, // output floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs - int B, int T, int NH, int C, cudaStream_t stream); + int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream); #endif // CUDNN_ATT_H \ No newline at end of file diff --git a/llmc/rope.cuh b/llmc/rope.cuh index 50371c47b..3bf5672e3 100644 --- a/llmc/rope.cuh +++ b/llmc/rope.cuh @@ -43,20 +43,27 @@ void precompute_freqs_cis(floatX *freqs_cis, int dim, int end, float theta, int } } -__global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { +__global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int Nq, int Nkv, int head_dim) { int idx = blockIdx.x * blockDim.x + threadIdx.x; int head_dim_half = head_dim / 2; - if (idx >= B * T * 3 * n_head * head_dim_half) return; + int N = Nq + 2*Nkv; + if (idx >= B * T * N * head_dim_half) return; // decode the qkv index early so we can early exit if it's a value index - int qkv = (idx / (n_head * head_dim_half)) % 3; + int h = (idx / head_dim_half) % N; + int qkv = 2; + if(h < Nq) { + qkv = 0; // query head + } else if (h < Nq + Nkv) { + qkv = 1; // key head + h -= Nq; + } if (qkv == 2) return; // no-op for v // decode the individual indices and get the input index - int b = idx / (T * 3 * n_head * head_dim_half); - int t = (idx / (3 * n_head * head_dim_half)) % T; - int h = (idx / head_dim_half) % n_head; + int b = idx / (T * N * head_dim_half); + int t = (idx / (N * head_dim_half)) % T; int d = idx % head_dim_half; - int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); - int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; + int idx_bt = b * (T * N * head_dim) + t * (N * head_dim); + int idx_bth = idx_bt + qkv * (Nq * head_dim) + h * head_dim; int idxi = idx_bth + 2 * d; // index in the input // fetch the freqs_cis int freqs_idx = t * head_dim + 2 * d; @@ -70,20 +77,27 @@ __global__ void rope_forward_kernel1(floatX *out, const floatX *inp, const float out[idxi + 1] = x_real * freqs_sin + x_imag * freqs_cos; } -__global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim) { +__global__ void rope_backward_kernel1(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int Nq, int Nkv, int head_dim) { int idx = blockIdx.x * blockDim.x + threadIdx.x; int head_dim_half = head_dim / 2; - if (idx >= B * T * 3 * n_head * head_dim_half) return; + int N = Nq + 2*Nkv; + if (idx >= B * T * N * head_dim_half) return; // decode the qkv index early so we can early exit if it's a value index - int qkv = (idx / (n_head * head_dim_half)) % 3; + int h = (idx / head_dim_half) % N; + int qkv = 2; + if(h < Nq) { + qkv = 0; // query head + } else if (h < Nq + Nkv) { + qkv = 1; // key head + h -= Nq; + } if (qkv == 2) return; // no-op for v // decode the individual indices and get the input index - int b = idx / (T * 3 * n_head * head_dim_half); - int t = (idx / (3 * n_head * head_dim_half)) % T; - int h = (idx / head_dim_half) % n_head; + int b = idx / (T * N * head_dim_half); + int t = (idx / (N * head_dim_half)) % T; int d = idx % head_dim_half; - int idx_bt = b * (T * 3 * n_head * head_dim) + t * (3 * n_head * head_dim); - int idx_bth = idx_bt + qkv * (n_head * head_dim) + h * head_dim; + int idx_bt = b * (T * N * head_dim) + t * (N * head_dim); + int idx_bth = idx_bt + qkv * (Nq * head_dim) + h * head_dim; int idxi = idx_bth + 2 * d; // index in the input // fetch the freqs_cis int freqs_idx = t * head_dim + 2 * d; @@ -96,23 +110,23 @@ __global__ void rope_backward_inplace_kernel1(floatX *dinp, const floatX *dout, dinp[idxi + 1] = -dout_real * freqs_sin + dout_imag * freqs_cos; } -void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { - // the input and output to this kernel are (B, T, 3, NH, HD) where the 3 is q,k,v +void rope_forward(floatX *out, const floatX *inp, const floatX *freqs_cis, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream) { + // the input and output to this kernel are (B, T, Nq + Nk + Nv, HD) // we are going to launch exactly one thread per element of the output, // except divide by two because the work is in "tuples" // so this single kernel launch will do RoPE for both q and k, and the threads for v will be a no-op const int block_size = 128; - int total_threads = B * T * 3 * n_head * head_dim / 2; + int total_threads = B * T * (Nq + 2*Nkv) * head_dim / 2; int num_blocks = CEIL_DIV(total_threads, block_size); - rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, n_head, head_dim); + rope_forward_kernel1<<>>(out, inp, freqs_cis, B, T, Nq, Nkv, head_dim); cudaCheck(cudaGetLastError()); } -void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int n_head, int head_dim, cudaStream_t stream) { +void rope_backward_inplace(floatX *dinp, const floatX *dout, const floatX *freqs_cis, int B, int T, int Nq, int Nkv, int head_dim, cudaStream_t stream) { // backward pass of forward, mirrors the forward kernel in setup and indexing const int block_size = 128; - int total_threads = B * T * 3 * n_head * head_dim / 2; + int total_threads = B * T * (Nq + 2*Nkv) * head_dim / 2; int num_blocks = CEIL_DIV(total_threads, block_size); - rope_backward_inplace_kernel1<<>>(dinp, dout, freqs_cis, B, T, n_head, head_dim); + rope_backward_kernel1<<>>(dinp, dout, freqs_cis, B, T, Nq, Nkv, head_dim); cudaCheck(cudaGetLastError()); } diff --git a/train_gpt2.cu b/train_gpt2.cu index d1af34f21..6ff844cda 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -718,7 +718,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { #ifdef ENABLE_CUDNN float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); - attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); + attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, NH, C / NH, main_stream); #else floatX* l_att = acts.att + l * B * NH * T * T; if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) @@ -909,7 +909,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int #ifdef ENABLE_CUDNN float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, NH, C / NH, main_stream); #else floatX* l_att = acts.att + l * B * NH * T * T; // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory diff --git a/train_llama3.cu b/train_llama3.cu index 7ea5a6e5d..252d2583f 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -712,25 +712,28 @@ void llama3_forward(LLama3 *model, const int* inputs, size_t B, size_t T) { floatX* l_fch_swiglu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * ffn_channels_post_gelu : acts.fch_gelu; floatX* l_residual3 = acts.residual3 + l * B * T * C; floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. - floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication // Attention block // The input l_ln1 now holds the (already layernormed) input #ifdef ENABLE_CUDNN - printf("cuDNN path TODO\n"); exit(0); + // 1) projection to QKV vectors (note k,v may be fewer heads than q) matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); + // 2) apply RoPE to q,k in place + rope_forward(l_qkvr, l_qkvr, model->freqs_cis, B, T, n_head, n_kv_head, hd, main_stream); + // 4) attention: att <- softmax(qk^T)v float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); + attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, n_kv_head, hd, main_stream); #else // unused parts of attention buffer must be zeroed (T-dependent) floatX* l_att = acts.att + l * B * NH * T * T; + floatX* qkv_rep_scratch = (floatX*)acts.scratch_bt4c; // we can use the BT4C scratch for qkv replication if (T != model->seq_len) { cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); } // 1) projection to QKV vectors (note k,v may be fewer heads than q) matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, qkv_channels, main_stream); - // 2) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now + // 2) apply RoPE to q,k in place + rope_forward(scratch, scratch, model->freqs_cis, B, T, n_head, n_kv_head, hd, main_stream); + // 3) replicate k,v so that all of q,k,v have the same number of heads. done for simplicity, for now repkv_forward(qkv_rep_scratch, scratch, B, T, n_head, n_kv_head, hd, main_stream); - // 3) apply RoPE to q,k in place - rope_forward(qkv_rep_scratch, qkv_rep_scratch, model->freqs_cis, B, T, n_head, hd, main_stream); // 4) attention: att <- softmax(qk^T)v attention_forward(l_atty, l_qkvr, l_att, qkv_rep_scratch, B, T, C, NH, main_stream); #endif @@ -927,20 +930,19 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets, // <--- gradient here matches OK #ifdef ENABLE_CUDNN - printf("cuDNN path TODO\n"); exit(0); float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + attention_backward_cudnn(dl_bt4c2, dl_btc, l_qkvr, l_atty, l_att, B, T, NH, n_kv_head, hd, main_stream); #else floatX* l_att = acts.att + l * B * NH * T * T; // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory floatX* buffer_a = l_atty; floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); - #endif - // backward rope (this can be done in-place) - rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); // backward repkv (use scratchX as gradient buffer here) repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd, main_stream); + #endif + // backward rope (this can be done in-place) + rope_backward_inplace(dl_bt4c2, dl_bt4c2, model->freqs_cis, B, T, NH, n_kv_head, hd, main_stream); // backward QKV projection if(model->recompute >= 2) { rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); diff --git a/train_llama3.py b/train_llama3.py index 2fc64a644..ca0171339 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -297,6 +297,7 @@ def __post_init__(self): "meta-llama/Meta-Llama-3.1-8B": LLama3_8BConfig, "meta-llama/Llama-3.2-3B": LLama3_3BConfig, "meta-llama/Llama-3.2-1B": LLama3_1BConfig, + "unsloth/Llama-3.2-1B": LLama3_1BConfig, } @@ -1044,7 +1045,7 @@ def print0(*args, **kwargs): parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on") parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on") parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints") - parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-1B", help="chose the llama model") + parser.add_argument("--model", type=str, default="unsloth/Llama-3.2-1B", help="chose the llama model") parser.add_argument("--depth", type=int, default=-1, help="load only a subset of the model's layers") parser.add_argument("--untie", type=int, default=False, help="Untie token embeddings and LM-head, even if they are tied in the checkpoint.") # token layout for each step of the optimization