Draft
Changes from all commits
75 commits
09b47a7
llama3 starting point is at gpt-2 exact copy paste for both train/tes…
karpathy Sep 13, 2024
01bc4c6
first set of changes to match up the .py and the .cu version. default…
karpathy Sep 13, 2024
b883560
change the export code of Llama 3 to be very GPT-2 friendly, using a …
karpathy Sep 13, 2024
8866308
adapt the sizes of all the parameter tensors and load them from file.…
karpathy Sep 16, 2024
45026f6
make llama3cu phony
karpathy Sep 16, 2024
77e1d7a
add support for dataloader to serve uint32_t tokens, as necessary in …
karpathy Sep 16, 2024
72e6f1a
add new Encoder that does not use positional embeddings, like in llam…
karpathy Sep 16, 2024
234de31
introduce rmsnorm, unfused, forward
karpathy Sep 16, 2024
508c474
move debugging into fp32, so python has to write the fp32 version, an…
karpathy Sep 17, 2024
685617f
make fp32 path in .py code work correctly
karpathy Sep 17, 2024
56f956c
add repkv kernel to replicate K,V heads after the QKV projection
karpathy Sep 21, 2024
45401b4
DRAFT: Adding backward kernel for repkv
insop Sep 22, 2024
080e57f
CPU version tested
insop Sep 22, 2024
6c68657
Put cuda kernel caller placeholder
insop Sep 22, 2024
ad46043
WIP updating cuda kernel
insop Sep 22, 2024
42d09e8
minor clean up
insop Sep 22, 2024
fcc3466
Add minor change
insop Sep 22, 2024
de9c817
wip
insop Sep 24, 2024
76b40e4
integrate the repkv kernel with minor changes. use the bt4c buffer fo…
karpathy Sep 24, 2024
026e4ed
add RoPE PyTorch and C reference code
karpathy Sep 24, 2024
8336d2a
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 25, 2024
2ebf8f6
Add rmsnorm fused kernel
gordicaleksa Sep 25, 2024
52c7254
add the finished RoPE forward pass
karpathy Sep 25, 2024
6538df6
Merge pull request #769 from gordicaleksa/fused_rmsnorm
karpathy Sep 25, 2024
bb3c92d
integrate the fused rmsnorm forward
karpathy Sep 25, 2024
1826752
add swigul yaygit add -u!
karpathy Sep 25, 2024
0731b39
forward pass matchesgit add train_llama3.cu train_llama3.py ! losses …
karpathy Sep 25, 2024
8874c2c
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 25, 2024
3e5134d
Merge branch 'insop/llama3_wip' into insop/llama3
insop Sep 25, 2024
d1f2f64
Updated repkv_backward cuda kernel
insop Sep 26, 2024
31be5e7
add rmsnorm backward in dev/cuda, it seems to work surprisingly, and …
karpathy Sep 26, 2024
a2b66f1
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 26, 2024
102067f
oops i think i accidentally forgot to include swiglu.cuh
karpathy Sep 26, 2024
2c4b3cc
integrate our rmsnorm backward and move the other rmsnorm functions i…
karpathy Sep 26, 2024
cbf53e3
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 26, 2024
01c2895
Update RoPE naming
insop Sep 26, 2024
1b54612
i can backward through MLP block. Attention block is next
karpathy Sep 27, 2024
c8b348e
Merge pull request #764 from insop/insop/llama3
karpathy Sep 27, 2024
28e4a7f
small fixes, but still not too happy with this kernel, it wastes thre…
karpathy Sep 27, 2024
075e430
just pushing what i have. it's epsilon away from working sigh. basica…
karpathy Sep 27, 2024
8d49062
add backward kernel to dev/cuda for rope, to ensure correctness. but …
karpathy Sep 27, 2024
7d945e9
reshuffle repkv a bit, i wrote it from scratch. the kernel is still c…
karpathy Sep 27, 2024
e6481b6
fix bug with qkvr sizing, has to be 3*C. Credit to @ademeure for find…
karpathy Oct 1, 2024
9099a0a
ok the full backward now shows max abs diff of 3e-3, except for the e…
karpathy Oct 1, 2024
c746e06
take out debugging stuff. we can now run training loop for both model…
karpathy Oct 1, 2024
2602b46
BF16 opt state (m/v) with stochastic rounding, seems to work really w…
ademeure Oct 1, 2024
d808d78
Merge pull request #772 from ademeure/llama3_arun_new
karpathy Oct 1, 2024
2c5ced6
fix bug due to bf16 adamw mv
karpathy Oct 1, 2024
3745dac
define llama3.2 1B and 3B for export from python (will untie embeddin…
ngc92 Apr 13, 2025
4d7980c
renaming gpt2 -> llama3
ngc92 Apr 13, 2025
090341e
enable llama3 CI
ngc92 Apr 13, 2025
a94471c
use optimizer offloading when running in CI
ngc92 Apr 14, 2025
4983c46
fix: fully ignore biases
ngc92 Apr 13, 2025
6866623
fix: match pytorch learning rate in test file
ngc92 Apr 13, 2025
2c3fecc
fix: gradient checking
ngc92 Apr 13, 2025
24d9129
fix: ensure `freqs_cis` are not broken when calling `model.to(dtype)`…
ngc92 Apr 14, 2025
f8a43ce
fix: writing checkpoint
ngc92 Apr 14, 2025
5b92829
!! DROP THIS COMMIT !!
ngc92 Apr 13, 2025
9c52a95
fix: CPUOffloadOptimizer + gradient clipping is broken; we use an ine…
ngc92 Apr 14, 2025
49cef1d
Merge pull request #802 from ngc92/ngc92/llama3-dev
karpathy May 1, 2025
1c02d54
cudnn does not support fp32 -> remove this pointless test
ngc92 May 1, 2025
7b7d39c
include grad norm in logging
ngc92 May 2, 2025
d4347a7
ensure 32-bit master params in python training
ngc92 May 2, 2025
082d9fa
added missing stream argument for repkv_backward
ngc92 May 2, 2025
a860922
set stream for attention softmax
ngc92 May 4, 2025
f38eadc
allow reducing number of transformer blocks to make smaller models th…
ngc92 May 4, 2025
35e1ad6
enable storing the expected loss values in the state file, so we can …
ngc92 May 4, 2025
76a7cce
replace offload with smaller model
ngc92 May 4, 2025
d36f0e6
Merge pull request #811 from ngc92/llama-fixes
karpathy May 10, 2025
9c60616
fix out-of-bounds access in rmsnorm kernel
ngc92 Jun 26, 2025
ffcfe99
fix out-of-bounds access in rmsnorm kernel
ngc92 Jun 26, 2025
9688eef
enable tied embeddings
ngc92 Apr 14, 2025
9caeceb
command-line overwrite to forcibly untie embeddings for llama3.2 models
ngc92 Apr 14, 2025
5c17e4e
Merge pull request #821 from ngc92/out-of-bounds-bugfix
karpathy Jun 26, 2025
da88cb1
Merge pull request #803 from ngc92/ngc92/llama3-tied-weights
karpathy Jun 26, 2025
15 changes: 15 additions & 0 deletions .github/workflows/ci.yml
@@ -219,3 +219,18 @@ jobs:

      - name: Build project
        run: make -j4 -C dev/cuda

  build-llama3:
    runs-on: ubuntu-latest
    container:
      image: nvidia/cuda:12.4.1-devel-ubuntu22.04

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Build FP32
        run: PRECISION=FP32 make test_llama3cu train_llama3cu

      - name: Build BF16
        run: PRECISION=BF16 make test_llama3cu train_llama3cu
95 changes: 88 additions & 7 deletions .github/workflows/ci_gpu.yml
@@ -9,9 +9,10 @@ on:
  pull_request:
    branches:
      - master
      - llama3

jobs:
  build-and-test-gpu:
  build-and-test-gpt2:
    runs-on: ubicloud-gpu-standard-1-latest

    steps:
@@ -103,19 +104,98 @@ jobs:
          git clone https://github.com/NVIDIA/cudnn-frontend.git

      - name: Build with cuDNN
        run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu test_gpt2fp32cu train_gpt2fp32cu
        run: USE_CUDNN=1 make test_gpt2cu train_gpt2cu

      - name: Train model with cuDNN
        run: ./train_gpt2cu

      - name: Train model fp32 with cuDNN
        run: ./train_gpt2fp32cu

      - name: Execute testing program with cuDNN
        run: ./test_gpt2cu

      - name: Execute testing program fp32 with cuDNN
        run: ./test_gpt2fp32cu
  build-and-test-llama3:
    name: Build and test LLama3.2 1B
    runs-on: ubicloud-gpu-standard-1-latest
    env:
      HF_TOKEN: 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
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - run: echo "::add-mask::$HF_TOKEN"

      - name: Install OpenMP
        run: sudo apt-get update && sudo apt-get install -y libomp-dev

      - name: Install dependencies
        run: pip install -r requirements.txt

      - name: Run preprocessing
        run: python dev/data/tinyshakespeare.py --model_desc llama-3

      - name: Train model
        # use the first 10 layers, so that everything fits into the 20GB of
        # the A4000 Ada that we have in CI
        run: python train_llama3.py --write_tensors 1 --dtype float32 --depth 10

      - name: Build FP32 precision
        run: PRECISION=FP32 make test_llama3cu

      - name: Run default
        run: ./test_llama3cu

      - name: Run no recompute GeLU
        run: ./test_llama3cu -r 0

      - name: Run recompute LN
        run: ./test_llama3cu -r 2

      - name: Build BF16 precision
        run: PRECISION=BF16 make train_llama3cu test_llama3cu

      - name: Run default (BF16)
        run: ./test_llama3cu

      - name: Run no recompute GeLU (BF16)
        run: ./test_llama3cu -r 0

      - name: Run no master weights (BF16)
        run: ./test_llama3cu -w 0

      - name: Run recompute LN (BF16)
        run: ./test_llama3cu -r 2

  build-and-test-llama3-untied:
    name: Build and test LLama3.2 1B with untie weights
    runs-on: ubicloud-gpu-standard-1-latest
    env:
      HF_TOKEN: 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
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - run: echo "::add-mask::$HF_TOKEN"

      - name: Install OpenMP
        run: sudo apt-get update && sudo apt-get install -y libomp-dev

      - name: Install dependencies
        run: pip install -r requirements.txt

      - name: Run preprocessing
        run: python dev/data/tinyshakespeare.py --model_desc llama-3

      - name: Train model
        run: python train_llama3.py --write_tensors 1 --dtype float32 --untie 1 --depth 10

      - name: Build FP32 precision
        run: PRECISION=FP32 make test_llama3cu

      - name: Run default
        run: ./test_llama3cu

      - name: Build BF16 precision
        run: PRECISION=BF16 make train_llama3cu test_llama3cu

      - name: Run default
        run: ./test_llama3cu

  unit-tests-gpu:
    runs-on: ubicloud-gpu-standard-1-latest
@@ -126,3 +206,4 @@ jobs:

      - name: Test Device<->File IO
        run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io

13 changes: 12 additions & 1 deletion Makefile
@@ -243,8 +243,13 @@ else
PFLAGS = -DENABLE_BF16
endif

# Optimizer precision settings, enable to allow BF16 for AdamW m/v state (also affects state file)
ifeq ($(OPTIMIZER_LOW_PRECISION), 1)
PFLAGS += -DOPTIMIZER_LOW_PRECISION
endif
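# (assumed usage, mirroring how PRECISION is already passed on the command line in CI:
#  OPTIMIZER_LOW_PRECISION=1 PRECISION=BF16 make train_llama3cu)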

# PHONY means these targets will always be executed
.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu
.PHONY: all train_llama3cu test_llama3cu train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu

# Add targets
TARGETS = train_gpt2 test_gpt2
@@ -285,6 +290,12 @@ test_gpt2fp32cu: test_gpt2_fp32.cu
profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN)
	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)

train_llama3cu: train_llama3.cu $(NVCC_CUDNN)
	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)

test_llama3cu: test_llama3.cu $(NVCC_CUDNN)
	$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)

clean:
	$(REMOVE_FILES) $(TARGETS)
	$(REMOVE_BUILD_OBJECT_FILES)
13 changes: 13 additions & 0 deletions dev/cbridge/README.md
@@ -0,0 +1,13 @@
# cbridge

We'll use this directory for the PyTorch -> C bridge. We have some PyTorch code and we'd like to have the equivalent C implementation (which in turn usually serves as the reference for the CUDA kernels later).

For starters we have RoPE. E.g. generate the reference with PyTorch and then match it in C:

```bash
python rope.py
gcc -o rope rope.c -lm
./rope
```

The .py file writes a `rope.bin` file with the intermediate tensors.
101 changes: 101 additions & 0 deletions dev/cbridge/rmsnorm.py
@@ -0,0 +1,101 @@
"""
An RMSNorm PyTorch reference implementation.
This script then does forward/back and writes everything to file so we can
develop the CPU version, and eventually the GPU kernel as well.
"""

import math
import torch
import numpy as np
import torch.nn as nn
from torch.nn import functional as F

# -----------------------------------------------------------------------------

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + self.eps
        rstd = torch.rsqrt(mean_sq)
        norm = x * rstd
        return norm

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def rmsnorm_backward(x, w, dout, eps):
    # recompute the rstd, norm (or we could cache it in the forward pass)
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + eps # (B, T, 1)
    rstd = torch.rsqrt(mean_sq) # (B, T, 1)
    norm = x * rstd # (B, T, C)
    # gradients for weights
    dw = (dout * norm).sum((0, 1)) # (C)
    # gradients for input
    dnorm = dout * w # (B, T, C)
    dx = dnorm - norm * (dnorm * norm).mean(dim=-1, keepdim=True)
    dx *= rstd
    return dx, dw
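# Why the dx expression above is correct: with rstd = (mean(x^2) + eps)^(-1/2)
# and norm = x * rstd, we have d(rstd)/dx_j = -rstd^3 * x_j / C, so
#   dx_j = sum_i dnorm_i * (rstd * delta_ij + x_i * d(rstd)/dx_j)
#        = rstd * (dnorm_j - norm_j * mean(dnorm * norm))
# which is exactly the form computed in the last two lines of rmsnorm_backward.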

# -----------------------------------------------------------------------------

# seed the rng
torch.manual_seed(42)

B = 4
T = 64
C = 256
eps = 1e-5

inp = torch.randn(B, T, C, dtype=torch.float32)
inp.requires_grad = True

# rmsnorm
m = RMSNorm(C, eps=eps)
out = m(inp)

# loss can just be a weighted sum, with some fixed weights
wei = torch.randn_like(out, dtype=torch.float32)
loss = (out * wei).sum()
loss.backward()

# let's now do the backward pass manually
# backprop starts with the output gradient, which is exactly wei because of the loss function
dx, dw = rmsnorm_backward(inp, m.weight, wei, eps)
# let's assert that the gradients match
assert torch.allclose(dx, inp.grad, atol=1e-6)
assert torch.allclose(dw, m.weight.grad, atol=1e-6)
print("RMSNorm gradients match")
print("first 5 elements of dx comparison:")
print(dx.view(-1)[:5].tolist())
print(inp.grad.view(-1)[:5].tolist())
print("first 5 elements of dw comparison:")
print(dw.view(-1)[:5].tolist())
print(m.weight.grad.view(-1)[:5].tolist())
print("dx error:", (inp.grad.view(-1) - dx.view(-1)).abs().max().item())
print("dw error:", (m.weight.grad.view(-1) - dw.view(-1)).abs().max().item())

# save to .bin file so we can check correctness in C land
int_header = np.zeros(16, dtype=np.int32) # for ints
float_header = np.zeros(16, dtype=np.float32) # for floats
int_header[0] = 20240925 # magic number
int_header[1] = B
int_header[2] = T
int_header[3] = C
float_header[0] = eps

# write the hyperparameters, inputs, output, and input gradients to file
results_file = "rmsnorm.bin"
with open(results_file, "wb") as f:
    f.write(int_header.tobytes()) # 16 int32
    f.write(float_header.tobytes()) # 16 float32
    f.write(inp.detach().cpu().numpy().tobytes()) # B * T * C
    f.write(out.detach().cpu().numpy().tobytes()) # B * T * C
    f.write(wei.detach().cpu().numpy().tobytes()) # B * T * C
    f.write(inp.grad.detach().cpu().numpy().tobytes()) # B * T * C
    f.write(m.weight.grad.detach().cpu().numpy().tobytes()) # C
print("Saved results to %s" % results_file)