Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,11 @@ jobs:

- name: Build BF16
run: PRECISION=BF16 make test_llama3cu train_llama3cu

- name: Get cudnn
run: |
apt-get update && apt-get install -y git
git clone https://github.com/NVIDIA/cudnn-frontend.git

- name: Build cuDNN
run: PRECISION=BF16 USE_CUDNN=1 make test_llama3cu train_llama3cu
34 changes: 22 additions & 12 deletions .github/workflows/ci_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,29 +115,29 @@ jobs:
build-and-test-llama3:
name: Build and test LLama3.2 1B
runs-on: ubicloud-gpu-standard-1-latest
env:
HF_TOKEN: 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
container:
image: nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
- run: echo "::add-mask::$HF_TOKEN"
- run: echo "::add-mask::$(echo us_xrYQGKBiJeqDMlTxkGhSgjelZKYbJHTgDY | tr 'A-Za-z' 'N-ZA-Mn-za-m')"

- name: Install OpenMP
run: sudo apt-get update && sudo apt-get install -y libomp-dev
run: apt-get update && apt-get install -y libomp-dev libopenmpi-dev python3-pip

- name: Install dependencies
run: pip install -r requirements.txt

- name: Run preprocessing
run: python dev/data/tinyshakespeare.py --model_desc llama-3
run: HF_TOKEN=$(echo us_xrYQGKBiJeqDMlTxkGhSgjelZKYbJHTgDY | tr 'A-Za-z' 'N-ZA-Mn-za-m') python3 dev/data/tinyshakespeare.py --model_desc llama-3

- name: Train model
# use the first 10 layers, so that everything fits into the 20GB of
# the A4000 Ada that we have in CI
run: python train_llama3.py --write_tensors 1 --dtype float32 --depth 10
run: HF_TOKEN=$(echo us_xrYQGKBiJeqDMlTxkGhSgjelZKYbJHTgDY | tr 'A-Za-z' 'N-ZA-Mn-za-m') python3 train_llama3.py --write_tensors 1 --dtype float32 --depth 10

- name: Build FP32 precision
run: PRECISION=FP32 make test_llama3cu
run: PRECISION=FP32 NO_MULTI_GPU=1 make test_llama3cu

- name: Run default
run: ./test_llama3cu
Expand All @@ -149,7 +149,7 @@ jobs:
run: ./test_llama3cu -r 2

- name: Build BF16 precision
run: PRECISION=BF16 make train_llama3cu test_llama3cu
run: PRECISION=BF16 NO_MULTI_GPU=1 make train_llama3cu test_llama3cu

- name: Run default (BF16)
run: ./test_llama3cu
Expand All @@ -166,15 +166,12 @@ jobs:
build-and-test-llama3-untied:
name: Build and test LLama3.2 1B with untie weights
runs-on: ubicloud-gpu-standard-1-latest
env:
HF_TOKEN: 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
steps:
- name: Checkout code
uses: actions/checkout@v4
- run: echo "::add-mask::$HF_TOKEN"

- name: Install OpenMP
run: sudo apt-get update && sudo apt-get install -y libomp-dev
run: sudo apt-get update && sudo apt-get install -y libomp-dev git

- name: Install dependencies
run: pip install -r requirements.txt
Expand All @@ -197,6 +194,19 @@ jobs:
- name: Run default
run: ./test_llama3cu

- name: Install cuDNN-frontend
run:
git clone https://github.com/NVIDIA/cudnn-frontend.git

- name: Build with cuDNN
run: USE_CUDNN=1 PRECISION=BF16 NO_MULTI_GPU=1 make train_llama3cu test_llama3cu

- name: Train model with cuDNN
run: ./train_llama3cu

- name: Execute testing program with cuDNN
run: ./test_llama3cu

unit-tests-gpu:
runs-on: ubicloud-gpu-standard-1-latest

Expand Down
2 changes: 1 addition & 1 deletion dev/data/fineweb.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

def tokenize_llama(doc):
# tokenizes a single document and returns a numpy array of uint32 tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
tokens = [eot] # the special <|endoftext|> token delimits all documents
Expand Down
2 changes: 1 addition & 1 deletion dev/data/tinyshakespeare.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def tokenize(model_desc):
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
elif model_desc == "llama-3":
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
else:
Expand Down
2 changes: 1 addition & 1 deletion dev/data/tinystories.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def process_shard(shard_index, shard_filename, model_desc):
encode = lambda s: enc.encode_ordinary(s)
eot = enc._special_tokens['<|endoftext|>'] # end of text token
elif model_desc == "llama-3":
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B")
encode = lambda s: tokenizer.encode(s, add_special_tokens=False, verbose=False, split_special_tokens=True)
eot = tokenizer.encode('')[0] # by default the tokenizer adds the EOT token (128000)
else:
Expand Down
94 changes: 48 additions & 46 deletions llmc/cudnn_att.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ enum UIDs {
};

// Need a cache because graph->build_operation_graph() is slow but everything else seems fast
using cache_type_fwd = std::map<std::tuple<int,int,int,int, int>, std::shared_ptr<fe::graph::Graph>>;
using cache_type_bwd = std::map<std::tuple<int,int,int,int>, std::shared_ptr<fe::graph::Graph>>;
using cache_type_fwd = std::map<std::tuple<int,int,int,int,int,int>, std::shared_ptr<fe::graph::Graph>>;
using cache_type_bwd = std::map<std::tuple<int,int,int,int,int>, std::shared_ptr<fe::graph::Graph>>;

// Loosely based on cuDNN frontend samples functions and massively simplified
auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_only) {
auto lookup_cache_or_build_graph_fwd(int B, int Hq, int Hkv, int T, int HS, int is_inference_only) {

static cache_type_fwd user_maintained_cache_fwd;

auto key = std::make_tuple(B, H, T, HS, is_inference_only);
auto key = std::make_tuple(B, Hq, Hkv, T, HS, is_inference_only);

auto it = user_maintained_cache_fwd.find(key);
if (it != user_maintained_cache_fwd.end()) {
Expand All @@ -74,18 +74,21 @@ auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_
.set_compute_data_type(fe::DataType_t::FLOAT);

// QKV is (B, T, 3, NH, HS) which cuDNN can handle directly without an external permute
// for (B, N, (NH + 2*(NH/replicate_factor)) * HS)
// (B, T, Hq + 2Hkv, HS)
int H = Hq + 2 * Hkv;
auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
.set_dim({B, H, T, HS})
.set_dim({B, Hq, T, HS})
.set_uid(Q_UID)
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
.set_stride({H * HS * T, HS, H * HS, 1}));
auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
.set_dim({B, H, T, HS})
.set_dim({B, Hkv, T, HS})
.set_uid(K_UID)
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
.set_stride({H * HS * T, HS, H * HS, 1}));
auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
.set_dim({B, H, T, HS})
.set_dim({B, Hkv, T, HS})
.set_uid(V_UID)
.set_stride({3 * H * HS * T, HS, 3 * H * HS, 1}));
.set_stride({H * HS * T, HS, H * HS, 1}));
auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
Expand All @@ -101,14 +104,14 @@ auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_
// Create the graph operation and get the output tensors back
auto [O, stats] = graph->sdpa(Q, K, V, sdpa_options);

// Output is (B, T, NH, HS) BF16/FP16 and stats for backward pass is (B, NH, T) FP32
O->set_output(true).set_dim({B, H, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(O_UID);
// Output is (B, T, Hq, HS) BF16/FP16 and stats for backward pass is (B, Hq, T) FP32
O->set_output(true).set_dim({B, Hq, T, HS}).set_stride({Hq * HS * T, HS, Hq * HS, 1}).set_uid(O_UID);

assert(stats == nullptr || is_inference_only == false);
if (is_inference_only == false) {
stats->set_output(true).set_data_type(fe::DataType_t::FLOAT)
.set_dim({B, H, T, 1})
.set_stride({H * T, T, 1, 1})
.set_dim({B, Hq, T, 1})
.set_stride({Hq * T, T, 1, 1})
.set_uid(Stats_UID);
}

Expand All @@ -134,10 +137,10 @@ auto lookup_cache_or_build_graph_fwd(int B,int H,int T,int HS, int is_inference_
return graph;
}

auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
auto lookup_cache_or_build_graph_bwd(int B, int Hq, int Hkv, int T, int HS) {
static cache_type_bwd user_maintained_cache_bwd;

auto key = std::make_tuple(B, NH, T, HS);
auto key = std::make_tuple(B, Hq, Hkv, T, HS);

auto it = user_maintained_cache_bwd.find(key);
if (it != user_maintained_cache_bwd.end()) {
Expand All @@ -151,31 +154,32 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {

// (B, N, 3, NH, HS)
// must come from inp (which means we also need to convert THAT to FP16)
int H = Hq + 2*Hkv;
auto Q = graph->tensor(fe::graph::Tensor_attributes().set_name("Q")
.set_dim({B, NH, T, HS})
.set_dim({B, Hq, T, HS})
.set_uid(Q_UID)
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
.set_stride({H * HS * T, HS, H * HS, 1}));
auto K = graph->tensor(fe::graph::Tensor_attributes().set_name("K")
.set_dim({B, NH, T, HS})
.set_dim({B, Hkv, T, HS})
.set_uid(K_UID)
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
.set_stride({H * HS * T, HS, H * HS, 1}));
auto V = graph->tensor(fe::graph::Tensor_attributes().set_name("V")
.set_dim({B, NH, T, HS})
.set_dim({B, Hkv, T, HS})
.set_uid(V_UID)
.set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}));
.set_stride({H * HS * T, HS, H * HS, 1}));
auto O = graph->tensor(fe::graph::Tensor_attributes().set_name("O")
.set_dim({B, NH, T, HS})
.set_dim({B, Hq, T, HS})
.set_uid(O_UID)
.set_stride({NH * HS * T, HS, NH * HS, 1}));
.set_stride({Hq * HS * T, HS, Hq * HS, 1}));
auto dO = graph->tensor(fe::graph::Tensor_attributes().set_name("dO")
.set_dim({B, NH, T, HS})
.set_dim({B, Hq, T, HS})
.set_uid(dO_UID)
.set_stride({NH * HS * T, HS, NH * HS, 1}));
.set_stride({Hq * HS * T, HS, Hq * HS, 1}));

auto stats = graph->tensor(fe::graph::Tensor_attributes().set_name("stats")
.set_dim({B, NH, T, 1})
.set_dim({B, Hq, T, 1})
.set_uid(Stats_UID)
.set_stride({NH * T, T, 1, 1})
.set_stride({Hq * T, T, 1, 1})
.set_data_type(fe::DataType_t::FLOAT));
auto attn_scale = graph->tensor(fe::graph::Tensor_attributes().set_name("attn_scale")
.set_dim({1, 1, 1, 1})
Expand All @@ -193,9 +197,9 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
// Create the graph operation and get the output tensors back
auto [dQ, dK, dV] = graph->sdpa_backward(Q, K, V, O, dO, stats, sdpa_backward_options);

dQ->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dQ_UID);
dK->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dK_UID);
dV->set_output(true).set_dim({B, NH, T, HS}).set_stride({3 * NH * HS * T, HS, 3 * NH * HS, 1}).set_uid(dV_UID);
dQ->set_output(true).set_dim({B, Hq, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(dQ_UID);
dK->set_output(true).set_dim({B, Hkv, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(dK_UID);
dV->set_output(true).set_dim({B, Hkv, T, HS}).set_stride({H * HS * T, HS, H * HS, 1}).set_uid(dV_UID);

checkCudnnFE(graph->validate());

Expand All @@ -219,23 +223,22 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
return graph;
}

void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
float* stats, // output for backward pass: (B, NH, T)
floatX* inp, // input: (B, T, 3, NH, HS) QKV
int B, int T, int NH, int C, cudaStream_t stream) {
void attention_forward_cudnn(floatX* out, // output: (B, T, Hq, HS)
float* stats, // output for backward pass: (B, Hq, T)
floatX* inp, // input: (B, T, Hq + Hk + Hv, HS) QKV
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head
bool is_inference_only = (stats == nullptr);

cuDNNCheck(cudnnSetStream(cudnn_handle, stream));

// Get graph and tensors from cache (or generate it on first use)
auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);
auto graph = lookup_cache_or_build_graph_fwd(B, Hq, Hkv, T, HS, is_inference_only);

// Prepare all the tensor pointers for executing the graph
void* devPtrQ = inp;
void* devPtrK = (inp + C);
void* devPtrV = (inp + 2 * C);
void* devPtrK = (inp + Hq * HS);
void* devPtrV = (inp + (Hq + Hkv) * HS);
float attn_scale_cpu = 1.0 / sqrtf(HS);
void* devPtrO = out;

Expand All @@ -255,25 +258,24 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)

void attention_backward_cudnn(floatX* dqkvr, // output
floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
int B, int T, int NH, int C, cudaStream_t stream) {
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head

// Get graph and tensors from cache (or generate it on first use)
auto graph = lookup_cache_or_build_graph_bwd(B, NH, T, HS);
auto graph = lookup_cache_or_build_graph_bwd(B, Hq, Hkv, T, HS);

// Prepare all the tensor pointers for executing the graph
void* devPtrQ = qkvr;
void* devPtrK = (qkvr + NH * HS);
void* devPtrV = (qkvr + 2 * NH * HS);
void* devPtrK = (qkvr + Hq * HS);
void* devPtrV = (qkvr + (Hq + Hkv) * HS);
void* devPtrO = o;
void* devPtrdO = dout;
void* devPtrStats = stats;
float attn_scale_cpu = 1.0 / sqrtf(HS);

void* devPtrdQ = dqkvr;
void* devPtrdK = (dqkvr + NH * HS);
void* devPtrdV = (dqkvr + 2 * NH * HS);
void* devPtrdK = (dqkvr + Hq * HS);
void* devPtrdV = (dqkvr + (Hq + Hkv) * HS);

// Build variant pack that links each tensor to its data pointer
std::unordered_map<int64_t, void*> variant_pack = {
Expand Down
10 changes: 5 additions & 5 deletions llmc/cudnn_att.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ cuDNN (flash) attention
// forward declarations of functions defined in cudnn_att.cpp
void create_cudnn();
void destroy_cudnn();
void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
float* stats, // output for backward pass: (B, NH, T)
floatX* inp, // input: (B, T, 3, NH, HS) QKV
int B, int T, int NH, int C, cudaStream_t stream);
void attention_forward_cudnn(floatX* out, // output: (B, T, Nq, HS)
float* stats, // output for backward pass: (B, Hq, T)
floatX* inp, // input: (B, T, Hq + 2Hkv, HS) QKV
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);

void attention_backward_cudnn(floatX* dqkvr, // output
floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
int B, int T, int NH, int C, cudaStream_t stream);
int B, int T, int Hq, int Hkv, int HS, cudaStream_t stream);

#endif // CUDNN_ATT_H
Loading