diff --git a/Makefile b/Makefile index 6fa511db4..b9c174dd0 100644 --- a/Makefile +++ b/Makefile @@ -269,7 +269,7 @@ $(NVCC_CUDNN): llmc/cudnn_att.cpp $(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_INCLUDES) -o $@ train_gpt2cu: train_gpt2.cu $(NVCC_CUDNN) - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) train_gpt2fp32cu: train_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 4453576ee..340bc492a 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -15,84 +15,151 @@ __device__ float lerp(float start, float end, float weight) { return fma(weight, end, fma(-weight, start, start)); } -template -__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, - float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, - float grad_scale, unsigned int seed) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_parameters) { return; } // guard - - // get the gradient, m, and v for this parameter - float grad = grad_scale * (float)grads_memory[idx]; - float m = m_memory[idx]; - float v = v_memory[idx]; - // update the first moment (momentum) - m = lerp(grad, m, beta1); - m_memory[idx] = m; - // update the second moment (RMSprop) - v = lerp(grad * grad, v, beta2); - v_memory[idx] = v; - m /= beta1_correction; // m_hat - v /= beta2_correction; // v_hat - // fetch the old value of this parameter as a float, from either source - float old_param = (master_params_memory != NULL) ? master_params_memory[idx] : (float)params_memory[idx]; - // update this parameter - float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param)); - // update our low precision version of the parameters using stochastic rounding - // this will be used in the next forward pass - stochastic_rounding(param, ¶ms_memory[idx], seed); - // write the full, float version of the param into our master copy, if we maintain one - // this will be used in the next update - if (master_params_memory != NULL) { master_params_memory[idx] = param; } -} +template +__device__ size_t adamw_update_part(TensorGPU param_tensor, + size_t idx, size_t current_start, size_t current_end, size_t stride, unsigned int seed, unsigned int shard_idx, + TensorGPU grad_tensor, TensorGPU master_tensor, TensorGPU opt_m_tensor, TensorGPU opt_v_tensor, + float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, + float eps, float wd, float grad_scale, int t) { + auto out_master128 = new_tensor128(master_tensor, true); + auto out_opt_m128 = new_tensor128(opt_m_tensor, true); + auto out_opt_v128 = new_tensor128(opt_v_tensor, true); + auto out_param128 = new_tensor128(param_tensor); -template -__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, - float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, - float grad_scale, unsigned int seed) { - adamw_update(params_memory + blockIdx.y * w_stride, - master_params_memory ? 
master_params_memory + blockIdx.y * s_stride : NULL, - grads_memory + blockIdx.y * g_stride, - m_memory + blockIdx.y * s_stride, - v_memory + blockIdx.y * s_stride, - num_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, - seed - ); -} + __syncthreads(); // todo - this should improve memory locality + while (idx < current_end) { + unsigned int random = get_random_noise(seed, idx); -template -__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_parameters) { return; } - params_memory += blockIdx.y * w_stride; // adjust for layer offset - master_params_memory += blockIdx.y * s_stride; - stochastic_rounding(master_params_memory[idx], ¶ms_memory[idx], seed); -} + tensor128 param128; + tensor128 grad128; + tensor128 opt_m128; + tensor128 opt_v128; + tensor128 master128; + int next_idx[TT::NUM_TYPES_PARAM] = {0}; + int current_idx[TT::NUM_TYPES_PARAM] = {0}; + + // todo - assuming either DPP or ZeRO 1 now (sharded optimizer/master, unsharded gradients/parameters) + // offset is 32-bit (checked <= elements in add_tensor_spec) + unsigned int offset = idx - current_start; + unsigned int unsharded_offset = offset + shard_idx * opt_v_tensor.num_elements; + + // this implementation has a stride causing sparse reads/writes and bank conflicts for non-FP8 + // todo - compare performance with a version that uses 128-bit for FP32, 64-bit for BF16, 32-bit for FP8 (probably much faster) + #pragma unroll + for (int i = 0; i < 16; i += 4, offset += 4, unsharded_offset += 4) { + if (current_idx[PARAMETER] == 0) param128 = load_tensor128(param_tensor, unsharded_offset); + if (current_idx[PARAMETER_GRAD] == 0) grad128 = load_tensor128(grad_tensor, unsharded_offset, false, true); + if (current_idx[PARAMETER_OPT_M] == 0) opt_m128 = load_tensor128(opt_m_tensor, offset, false,true); + if (current_idx[PARAMETER_OPT_V] == 0) opt_v128 = load_tensor128(opt_v_tensor, offset, false, true); + if (current_idx[PARAMETER_MASTER] == 0 && use_master_weights) master128 = load_tensor128(master_tensor, offset, false, true); + + for (int k = 0; k < 4; k++) { + float grad = grad128.get(current_idx[PARAMETER_GRAD] + k); + float m = opt_m128.get(current_idx[PARAMETER_OPT_M] + k); + float v = opt_v128.get(current_idx[PARAMETER_OPT_V] + k); + + m = lerp(grad, m, beta1); + v = lerp(grad * grad, v, beta2); + out_opt_m128.set(current_idx[PARAMETER_OPT_M] + k, m); + out_opt_v128.set(current_idx[PARAMETER_OPT_V] + k, v); + m /= beta1_correction; + v /= beta2_correction; + + float old_param; + if constexpr (use_master_weights) { + old_param = master128.get(current_idx[PARAMETER_MASTER] + k); + } else { + old_param = param128.get(current_idx[PARAMETER] + k); + } + + float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param)); + if constexpr (use_master_weights) { + out_master128.set(current_idx[PARAMETER_MASTER] + k, param); + } + out_param128.set_stochastic(current_idx[PARAMETER] + k, param, random); + } + next_idx[PARAMETER] = (i + 4) % (16 / sizeof(Tparam)); + next_idx[PARAMETER_GRAD] = (i + 4) % (16 / sizeof(Tgrad)); + next_idx[PARAMETER_OPT_M] = (i + 4) % (16 / sizeof(Tm)); + next_idx[PARAMETER_OPT_V] = (i + 4) % (16 / sizeof(Tv)); + next_idx[PARAMETER_MASTER] = (i + 4) % (16 / sizeof(Tmaster)); -template -void adamw_update(Tp* params_memory, float* master_params_memory, 
Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, - float grad_scale, unsigned int seed, cudaStream_t stream) { - // AdamW update - int block_size = 512; - int num_blocks = CEIL_DIV(num_parameters, block_size); - float beta1_correction = 1.0f - powf(beta1, t); - float beta2_correction = 1.0f - powf(beta2, t); - adamw_kernel3<<>>(params_memory, master_params_memory, grads_memory, - m_memory, v_memory, num_parameters, w_stride, g_stride, s_stride, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, - grad_scale, seed); - cudaCheck(cudaGetLastError()); + if (next_idx[PARAMETER] == 0) out_param128.store(unsharded_offset - current_idx[PARAMETER]); + if (next_idx[PARAMETER_OPT_M] == 0) out_opt_m128.store(offset - current_idx[PARAMETER_OPT_M]); + if (next_idx[PARAMETER_OPT_V] == 0) out_opt_v128.store(offset - current_idx[PARAMETER_OPT_V]); + if constexpr (use_master_weights) { + if (next_idx[PARAMETER_MASTER] == 0) out_master128.store(offset - current_idx[PARAMETER_MASTER]); + } + + for (int n = 0; n < TT::NUM_TYPES_PARAM; n++) { + current_idx[n] = next_idx[n]; + } + } + idx += stride; + } + out_param128.update_absmax(1); + return idx; } -template -void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream) { - int block_size = 512; // must match block size of adamw_update so that RNG also matches - int num_blocks = CEIL_DIV(num_parameters, block_size); - init_from_master_kernel<<>> - (params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed); - cudaCheck(cudaGetLastError()); +template +__global__ void adamw_update_everything(int num_params_tensors, int start_tensor, int last_tensor, unsigned int seed , unsigned int shard_idx, + float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, + float eps, float weight_decay, float grad_scale, int t) { + // ... + constexpr size_t block_size = 64; + constexpr size_t iteration_size = 16; // todo - this causes sparsit and bank clashes for FP32/BF16 loads/stores + size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); + unsigned int stride = gridDim.x * blockDim.x * iteration_size; + + int opt_m_spec_id = 2 * num_params_tensors; + int last_opt_m_id = opt_m_spec_id + last_tensor; // opt_m is sharded with ZeRO 1 so use it as reference + opt_m_spec_id += start_tensor - 1; // -1 to compensate for the increment at the start of the loop below + + while (true) { + size_t current_end; + do { + opt_m_spec_id++; + if (opt_m_spec_id > last_opt_m_id) return; // done! 
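            // (editorial note, for reference) The lerp()-based moment updates in adamw_update_part above
            // are the usual Adam EMAs: lerp(grad, m, beta1) = beta1*m + (1-beta1)*grad, and likewise for v
            // with grad*grad and beta2. Dividing by beta1_correction = 1 - beta1^t and
            // beta2_correction = 1 - beta2^t gives the bias-corrected m_hat and v_hat, and the step is the
            // decoupled AdamW update:
            //     param -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + wd * param)
            // where wd is zeroed for non-2D tensors further down in this kernel (TENSOR_2D flag check).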
+ + // on A100+ we can prefetch 256B (32 values) into the L2, on older GPUs just use a regular load + #if __CUDA_ARCH__ < 800 + current_end = tensor_end_element_ptr[opt_m_spec_id]; + #else + asm("ld.global.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); + #endif + } while (idx >= current_end); + + int spec_id = opt_m_spec_id - 2 * num_params_tensors; + size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element; + + TensorSpec param_spec = tensor_specs_ptr[spec_id]; + TensorGPU grad_tensor = tensor_specs_ptr[spec_id + 1*num_params_tensors]; + TensorGPU opt_m_tensor = tensor_specs_ptr[spec_id + 2*num_params_tensors]; + TensorGPU opt_v_tensor = tensor_specs_ptr[spec_id + 3*num_params_tensors]; + TensorGPU master_tensor = use_master_weights ? tensor_specs_ptr[spec_id + 4*num_params_tensors] : opt_m_tensor; + + float wd = (param_spec.tensor_flags & TENSOR_2D) ? weight_decay : 0.0f; + + if (param_spec.data_type == DType::FP32) { + idx = adamw_update_part((TensorGPU)param_spec, + idx, current_start, current_end, stride, seed, shard_idx, + grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, + eps, wd, grad_scale, t); + } else if (param_spec.data_type == DType::BF16) { + idx = adamw_update_part((TensorGPU<__nv_bfloat16>)param_spec, + idx, current_start, current_end, stride, seed, shard_idx, + grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, + eps, wd, grad_scale, t); + } else if (param_spec.data_type == DType::FP8E4M3) { + idx = adamw_update_part((TensorGPU<__nv_fp8_e4m3>)param_spec, + idx, current_start, current_end, stride, seed, shard_idx, + grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, + eps, wd, grad_scale, t); + } else { + assert(false); // TODO (no FP16 to avoid compile time increase but trivial to add here) + } + } } diff --git a/llmc/attention.cuh b/llmc/attention.cuh index f6294a213..72cd9c545 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -12,11 +12,11 @@ Attention, as a fallback when we do not use the Flash Attention from cuDNN // inputs floatX, outputs FP32 (for current FP32-only activation path for this WIP) __global__ void permute_kernel(floatX* q, floatX* k, floatX* v, - const floatX* inp, + tensorX inp, int B, int N, int NH, int d) { // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d) // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d) - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128(); if (idx >= B * NH * N * d) { return; } // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_] @@ -27,15 +27,21 @@ __global__ void permute_kernel(floatX* q, floatX* k, floatX* v, int n = rest / d; int d_ = rest % d; int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_; - q[idx] = __ldcs(&inp[inp_idx]); - k[idx] = __ldcs(&inp[inp_idx + NH * d]); - v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]); + + auto inp128_q = load_tensor128(inp, inp_idx, true); + auto inp128_k = load_tensor128(inp, inp_idx + NH * d, true); + auto inp128_v = load_tensor128(inp, inp_idx + 2 * (NH * d), true); + for (int i = 0; i < inp.num_per_128(); i++) { + q[idx+i] = inp128_q.get(i); + k[idx+i] = inp128_k.get(i); + v[idx+i] = inp128_v.get(i); + } } -__global__ void permute_kernel_backward(floatX* dinp, 
+__global__ void permute_kernel_backward(tensorX dinp, const floatX* dq, const floatX* dk, const floatX* dv, int B, int N, int NH, int d) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dinp.num_per_128(); if (idx >= B * NH * N * d) { return; } int b = idx / (NH * N * d); @@ -49,12 +55,29 @@ __global__ void permute_kernel_backward(floatX* dinp, dinp[inp_idx] = dq[idx]; dinp[inp_idx + NH * d] = dk[idx]; dinp[inp_idx + 2 * (NH * d)] = dv[idx]; + + auto dinp128_q = new_tensor128(dinp); + auto dinp128_k = new_tensor128(dinp); + auto dinp128_v = new_tensor128(dinp); + for (int i = 0; i < dinp.num_per_128(); i++) { + dinp128_q.set(i, dq[idx+i]); + dinp128_k.set(i, dk[idx+i]); + dinp128_v.set(i, dv[idx+i]); + + // to allow us to update the absmax only once for the q vector + dinp128_q.add_value_stats(dk[idx+i], dinp128_k.get128()[i]); + dinp128_q.add_value_stats(dv[idx+i], dinp128_v.get128()[i]); + } + dinp128_q.store(inp_idx); + dinp128_k.store(inp_idx + NH * d); + dinp128_v.store(inp_idx + 2 * (NH * d)); + dinp128_q.update_absmax(1); } -__global__ void unpermute_kernel(floatX* inp, floatX *out, int B, int N, int NH, int d) { +__global__ void unpermute_kernel(tensorX out, floatX* inp, int B, int N, int NH, int d) { // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d) - int idx = (blockIdx.x * blockDim.x + threadIdx.x); + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * out.num_per_128(); // out[b][n][nh_][d_] <- inp[b][nh_][n][d_] if (idx >= B * NH * N * d) { return; } @@ -65,11 +88,16 @@ __global__ void unpermute_kernel(floatX* inp, floatX *out, int B, int N, int NH, int n = rest / d; int d_ = rest % d; int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_; - out[other_idx] = __ldcs(&inp[idx]); + auto out128 = new_tensor128(out); + for (int i = 0; i < out.num_per_128(); i++) { + out128.set(i, __ldcs(&inp[idx + i])); + } + out128.store(other_idx); + out128.update_absmax(1); } -__global__ void unpermute_kernel_backward(floatX* dinp, const floatX *dout, int B, int N, int NH, int d) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; +__global__ void unpermute_kernel_backward(floatX* dout_permuted, tensorX dout, int B, int N, int NH, int d) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); if (idx >= B * NH * N * d) { return; } int b = idx / (NH * N * d); @@ -79,10 +107,13 @@ __global__ void unpermute_kernel_backward(floatX* dinp, const floatX *dout, int int n = rest / d; int d_ = rest % d; int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_; - dinp[idx] = (floatX)dout[other_idx]; + auto dout128 = load_tensor128(dout, other_idx); + for (int k = 0; k < dout128.elements; k++) { + dout_permuted[idx+k] = (floatX)dout128.get(k); + } } -__global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, const floatX* inp, int N, int T) { +__global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, floatX* inp, int N, int T) { // inp, out shape: (N, T, T), where N = B * NH // fuses the multiplication by scale inside attention // directly autoregressive, so we only compute the lower triangular part @@ -149,7 +180,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons } } -__global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, const floatX* att, +__global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, floatX* att, int B, int T, int C, float scale) { constexpr const int 
BlockSize = 256; constexpr int T_per_block = 4; @@ -192,9 +223,9 @@ __global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, con // ---------------------------------------------------------------------------- // kernel launchers -void attention_forward(floatX* out, floatX* qkvr, floatX* att, - floatX* inp, - int B, int T, int C, int NH, cudaStream_t stream) { +void attention_forward(tensorX out, floatX* qkvr, floatX* att, + tensorX inp, + int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer. // Its contents will be overwritten by this function. @@ -211,11 +242,11 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att, k = qkvr + 1 * B * T * C; v = qkvr + 2 * B * T * C; int total_threads = B * NH * T * HS; - int num_blocks = CEIL_DIV(total_threads, block_size); + int num_blocks = CEIL_DIV(total_threads, block_size * inp.num_per_128()); permute_kernel<<>>(q, k, v, inp, B, T, NH, HS); floatX* preatt = inp; // reuse inp as scratch buffer - matmul_cublaslt(preatt, k, q, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); + matmul_cublaslt(tensorX::from(preatt), tensorX::from(k), tensorX::from(q), null_tensorX, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); // multiply all elements of preatt elementwise by scale float scale = 1.f / sqrtf(HS); @@ -225,27 +256,29 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att, // new approach: first cuBLAS another batched matmul floatX* vaccum = inp; // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) - matmul_cublaslt(vaccum, v, att, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(vaccum), tensorX::from(v), tensorX::from(att), null_tensorX, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); // now unpermute // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - num_blocks = CEIL_DIV(B * T * C, block_size); - unpermute_kernel<<>>(vaccum, out, B, T, NH, HS); + num_blocks = CEIL_DIV(B * T * C, block_size * out.num_per_128()); + unpermute_kernel<<>>(out, vaccum, B, T, NH, HS); cudaCheck(cudaGetLastError()); } // the sequence of transformations in this compound op is: // inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C) -void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch, - const floatX* dout, - const floatX* qkvr, const floatX* att, - int B, int T, int C, int NH, cudaStream_t stream) { +void attention_backward(tensorX dinp, floatX* dqkvr, floatX* datt, + tensorX dout, tensorX qkvr, floatX* att, + int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int HS = C / NH; // head size + // now reusing dinp as scratch buffer (free before the final output and it's the right size) + floatX* scratch = dinp.data_ptr; + // unpack convenience pointers into q, k, v - const floatX *q, *k, *v; + floatX *q, *k, *v; q = qkvr + 0 * B * T * C; k = qkvr + 1 * B * T * C; v = qkvr + 2 * B * T * C; @@ -255,22 +288,24 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scrat dv = dqkvr + 2 * B * T * C; // backward through the unpermute operation - int num_blocks = CEIL_DIV(B * T * C, block_size); + int num_blocks = CEIL_DIV(B * T * C, block_size * dout.num_per_128()); unpermute_kernel_backward<<>>(scratch, 
dout, B, T, NH, HS); + cudaCheck(cudaGetLastError()); // backward into datt - matmul_cublaslt(datt, v, scratch, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); + matmul_cublaslt(tensorX::from(datt), tensorX::from(v), tensorX::from(scratch), null_tensorX, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); // backward into dv - matmul_cublaslt(dv, scratch, att, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dv), tensorX::from(scratch), tensorX::from(att), null_tensorX, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); const float scale = 1.0f / sqrtf((float)HS); // backward into preatt. this is an in-place operation; datt turns into dpreatt here softmax_autoregressive_backward_inplace_kernel<<>>(datt, att, B, T, C, scale); - const floatX* dpreatt = datt; + cudaCheck(cudaGetLastError()); + floatX* dpreatt = datt; // backward into q - matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dq), tensorX::from(k), tensorX::from(dpreatt), null_tensorX, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); // backward into k - matmul_cublaslt(dk, q, dpreatt, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dk), tensorX::from(q), tensorX::from(dpreatt), null_tensorX, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); // backward into inp - num_blocks = CEIL_DIV(B * NH * T * HS, block_size); + num_blocks = CEIL_DIV(B * NH * T * HS, block_size * dinp.num_per_128()); permute_kernel_backward<<>>(dinp, dq, dk, dv, B, T, NH, HS); cudaCheck(cudaGetLastError()); } diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h new file mode 100644 index 000000000..a784c0bd7 --- /dev/null +++ b/llmc/copy_and_fp8.h @@ -0,0 +1,157 @@ +/* +Helpers for FP8 including copy and transpose with format conversion, and absmax +See /dev/cuda/advanced_copy_transpose.cu for more information and options +*/ +#ifndef FP8_HELPERS_CUH +#define FP8_HELPERS_CUH + +#include +#include +#include "cuda_common.h" +#include "cuda_utils.cuh" + +// todo - tune these for performance (but should be close to optimal already) +#define TRANSPOSE_TILE_SIZE 64UL + +// ---------------------------------------------------------------------------- +// elementwise functions which can be applied as part of the copy/transpose +// for elementwise kernels that require metadata (e.g. layernorm forward with known mean/std), +// we could maybe store it in constant buffers rather than in yet-another-function-parameter... +using elementwise_func_t = float (*) (float); +__device__ float nothing_elementwise(float x) { + return x; +} +__device__ float gelu_forward_elementwise(float x) { + float cube = 0.044715f * x * x * x; + + float tanh_out; + float tanh_arg = sqrtf(2.0f / M_PI) * (x + cube); + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_out) : "f"(tanh_arg)); + + // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)" + float half_x = 0.5f * x; + return half_x * tanh_out + half_x; +} + +// ---------------------------------------------------------------------------- +// CUDA kernels + +// Advanced copy with optional format conversion, absmax, scaling and elementwise operation +template +__global__ void copy_advanced_kernel(TensorGPU out, TensorGPU in) { + constexpr size_t vec_size = 16 / ((sizeof(Tin) >= sizeof(Tout)) ? 
sizeof(Tin) : sizeof(Tout)); + size_t adjusted_blockidx = reversed_order ? (gridDim.x - blockIdx.x - 1) : blockIdx.x; + size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; + if (idx >= out.num_elements) { return; } + + auto inp128 = load_tensor128(in, idx, true, disable_scaling); + auto out128 = new_tensor128(out, disable_scaling); + for (int k = 0; k < vec_size; k++) { + float out_fp32 = elementwise_func(inp128.get(k)); + out128.set(k, out_fp32); + } + out128.template store_same_length(idx); + out128.update_absmax(1); +} + +template +__global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* __restrict__ input) +{ + constexpr size_t elements = 16 / sizeof(T1); + __shared__ T1 tile[TILE_DIM][TILE_DIM]; + int width = gridDim.x * TILE_DIM; + int height = gridDim.y * TILE_DIM; + + int x = blockIdx.x * TILE_DIM + threadIdx.x * elements; + int y = blockIdx.y * TILE_DIM + threadIdx.y; + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + Packed128 in128 = load128cs(input + x + (y+j)*width); + size_t tile_offset = (threadIdx.x * elements) + (threadIdx.y+j)*TILE_DIM; + store128(&tile[0][0] + tile_offset, in128); + } + __syncthreads(); + + // x/y for final write to global memory + x = blockIdx.y * TILE_DIM + threadIdx.x * elements; + y = blockIdx.x * TILE_DIM + threadIdx.y; + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + Packed128 out128; + #pragma unroll + for (int k = 0; k < elements; k++) { + // these are tiny 8-bit loads with loads of bank conflicts for FP8 + // extremely hard to avoid and not a bottleneck when everything else is well optimised + out128[k] = tile[k + threadIdx.x * elements][threadIdx.y + j]; + } + store128(transposed + x + (y+j)*height, out128); + } +} + +// only calculate absmax of the input tensor (non-fused) +template +__global__ void update_absmax_kernel(TensorGPU inp) { + size_t idx = ((blockIdx.x * blockDim.x) + threadIdx.x) * inp.num_per_128(); + auto max128 = new_tensor128(inp); + if (idx < inp.num_elements) { + auto inp128 = load_tensor128(inp, idx, disable_scaling); + for(int k = 0; k < inp.num_per_128(); ++k) { + float value = inp128.get(k); + max128.add_value_stats(value); + } + } + max128.update_absmax(threadIdx.x, blockDim.x, true, true); +} + +// ---------------------------------------------------------------------------- + +template +void copy_advanced(TensorGPU out, TensorGPU in, cudaStream_t stream=0, const size_t block_size=512) { + size_t N = out.num_elements; + size_t fewest_elements = min(Packed128::size, Packed128::size); + assert((N % fewest_elements) == 0); + + const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); + copy_advanced_kernel<<>>(out, in); + cudaCheck(cudaGetLastError()); +} + +template +void transpose_simple(TensorGPU transposed, TensorGPU input, size_t w, size_t h, cudaStream_t stream=0, size_t block_size=128) { + assert((w % TRANSPOSE_TILE_SIZE) == 0 && (h % TRANSPOSE_TILE_SIZE) == 0); + cudaCheck(cudaGetLastError()); + + size_t block_size_x = (TRANSPOSE_TILE_SIZE * sizeof(T1)) / 16; + size_t block_size_y = min(TRANSPOSE_TILE_SIZE, block_size / block_size_x); + dim3 grid_size(w / TRANSPOSE_TILE_SIZE, h / (TRANSPOSE_TILE_SIZE)); + dim3 block_size_dim(block_size_x, block_size_y, 1); + + switch (block_size_y) { + case 64: transpose_simple_kernel<64, TRANSPOSE_TILE_SIZE><<>>((T1*)transposed, (T1*)input); break; + case 32: transpose_simple_kernel<32, TRANSPOSE_TILE_SIZE><<>>((T1*)transposed, (T1*)input); break; + case 16: transpose_simple_kernel<16, 
TRANSPOSE_TILE_SIZE><<>>((T1*)transposed, (T1*)input); break; + default: printf("Invalid block size (might be easy to add): %lu\n", block_size_y); exit(1); + } + cudaCheck(cudaGetLastError()); +} + +template +void update_absmax(TensorGPU inp, bool memset_absmax=true, cudaStream_t stream=main_stream) { + size_t N = inp.num_elements; + if (N == 0 || inp.absmax_ptr == NULL) { return; } + assert(N % inp.num_per_128() == 0); + + size_t block_size = 512; + const dim3 grid_size(CEIL_DIV(N, block_size * Packed128::size)); + if (memset_absmax) { + cudaMemset(inp.absmax_ptr, 0, sizeof(unsigned int)); + } + update_absmax_kernel<<>>(inp); + cudaCheck(cudaGetLastError()); +} + +#endif \ No newline at end of file diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 006ad3010..a8b4607a3 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -15,6 +15,7 @@ Common utilities for CUDA code. #include #include #include +#include #include "utils.h" @@ -26,16 +27,24 @@ Common utilities for CUDA code. // but it is actually created and instantiated in the main program file extern cudaDeviceProp deviceProp; +// Main stream used by default for all CUDA operations +extern cudaStream_t main_stream; + // WarpSize is not a compile time constant // Defining here like this possibly allows the compiler to optimize better #define WARP_SIZE 32U -// try to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance +// optimise the number of blocks that fit to maximise latency tolerance // this needs to be defines rather than queried to be used for __launch_bounds__ -#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900 +#if __CUDA_ARCH__ >= 900 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ <= 700 #define MAX_1024_THREADS_BLOCKS 2 +#define MAX_THREADS 2048 // H100/A100/V100/Pascal/Maxwell(/Blackwell?) +#elif __CUDA_ARCH__ == 750 +#define MAX_1024_THREADS_BLOCKS 1 +#define MAX_THREADS 1024 // Turing #else #define MAX_1024_THREADS_BLOCKS 1 +#define MAX_THREADS 1536 // Consumer Ampere & Ada Lovelace #endif // convenience macro for calculating grid/block dimensions for kernels @@ -43,7 +52,7 @@ extern cudaDeviceProp deviceProp; // short-cuts for compile-time boolean values that can be used as function arguments constexpr std::bool_constant True; -constexpr std::bool_constant False; +constexpr std::bool_constant False; // ---------------------------------------------------------------------------- // Error checking @@ -82,13 +91,28 @@ enum PrecisionMode { #if defined(ENABLE_FP32) typedef float floatX; #define PRECISION_MODE PRECISION_FP32 +#define DTYPE_FLOATX DType::FP32 // use fp16 (note: this may require gradient scaler, currently not implemented!) #elif defined(ENABLE_FP16) typedef half floatX; #define PRECISION_MODE PRECISION_FP16 +#define DTYPE_FLOATX DType::FP16 #else // Default to bfloat16 typedef __nv_bfloat16 floatX; #define PRECISION_MODE PRECISION_BF16 +#define DTYPE_FLOATX DType::BF16 +#endif + +#if defined(ENABLE_FP8) +typedef __nv_fp8_e4m3 float8; +typedef __nv_fp8_e5m2 float8e5; +#define DTYPE_FP8E4 DType::FP8E4M3 +#define DTYPE_FP8E5 DType::FP8E5M2 +#else +typedef floatX float8; +typedef floatX float8e5; +#define DTYPE_FP8E4 DTYPE_FLOATX +#define DTYPE_FP8E5 DTYPE_FLOATX #endif // ---------------------------------------------------------------------------- @@ -127,7 +151,7 @@ class NvtxRange { // Utilities to Read & Write between CUDA memory <-> files // copy num_bytes from device pointer src into file dest, using double buffering running on the given stream. 
-inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { +inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream=main_stream) { // allocate pinned buffer for faster, async transfer char* buffer_space; cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size)); @@ -166,7 +190,7 @@ inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffe } // copy num_bytes from file src into device pointer dest, using double buffering running on the given stream. -inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { +inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream=main_stream) { // allocate pinned buffer for faster, async transfer // from the docs (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDART__HIGHLEVEL_ge439496de696b166ba457dab5dd4f356.html) // WC memory is a good option for buffers that will be written by the CPU and read by the device via mapped pinned memory or host->device transfers. diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 0ce728ee1..02e16d0eb 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -2,7 +2,6 @@ #ifndef CUDA_UTILS_CUH #define CUDA_UTILS_CUH - #include "cuda_common.h" // ---------------------------------------------------------------------------- @@ -51,30 +50,58 @@ struct alignas(16) Packed128 { // load a Packed128 from an aligned memory address template -__device__ Packed128 load128(const ElementType* address) { +__device__ Packed128 load128(const ElementType* __restrict__ address) { return Packed128{*reinterpret_cast(address)}; } // load a Packed128 from an aligned memory address with streaming cache hint template -__device__ Packed128 load128cs(const ElementType* address) { +__device__ Packed128 load128cs(const ElementType* __restrict__ address) { return Packed128{__ldcs(reinterpret_cast(address))}; } // store a Packed128 to an aligned memory address template -__device__ void store128(ElementType* target, Packed128 value) { +__device__ void store128(ElementType* __restrict__ target, Packed128 value) { *reinterpret_cast(target) = value.get_bits(); } // store a Packed128 to an aligned memory address with streaming cache hint template -__device__ void store128cs(ElementType* target, Packed128 value) { +__device__ void store128cs(ElementType* __restrict__ target, Packed128 value) { __stcs(reinterpret_cast(target), value.get_bits()); } // store a Packed128 to an aligned memory address while caching in L2 but bypassing L1 template -__device__ void store128cg(ElementType* target, Packed128 value) { +__device__ void store128cg(ElementType* __restrict__ target, Packed128 value) { __stcg(reinterpret_cast(target), value.get_bits()); } +// This helper is for when we want to copy from e.g. 
FP32 to BF16 +// so if want to load a f128 of 4 elements, and write those 4 elements to memory as 64-bit +// not needed in the case of loads, the compiler will automatically optimise away unused reads +template +__device__ void store128_same_length(ElementType* target, Packed128 value) { + int4 bits = value.get_bits(); + switch (sizeof(OriginalType) / sizeof(ElementType)) { + case 0: *reinterpret_cast(target) = bits; break; // smaller + case 1: *reinterpret_cast(target) = bits; break; // same size + case 2: *reinterpret_cast(target) = make_int2(bits.x, bits.y); break; + case 4: *reinterpret_cast(target) = bits.x; break; + default: break; //assert(false); + } +} + +// with streaming cache hint (low persistence in L1/L2 caches) +template +__device__ void store128_same_length_cs(ElementType* target, Packed128 value) { + int4 bits = value.get_bits(); + switch (sizeof(OriginalType) / sizeof(ElementType)) { + case 0: __stcs(reinterpret_cast(target), bits); break; // smaller + case 1: __stcs(reinterpret_cast(target), bits); break; // same size + case 2: __stcs(reinterpret_cast(target), make_int2(bits.x, bits.y)); break; + case 4: __stcs(reinterpret_cast(target), bits.x); break; + default: break; //assert(false); + } +} + // short-form typedefs typedef Packed128 f128; typedef Packed128 x128; @@ -84,7 +111,7 @@ typedef Packed128 x128; // enumerator to indentify the datatype of a tensor. enum class DType : uint8_t { - FP32, FP16, BF16 + FP32, FP16, BF16, FP8E4M3, FP8E5M2 }; // Given a datatype enum, returns the underlying number of bytes @@ -97,6 +124,10 @@ size_t sizeof_dtype(DType type) { return sizeof(half); case DType::BF16: return sizeof(nv_bfloat16); + case DType::FP8E4M3: + return sizeof(__nv_fp8_e4m3); + case DType::FP8E5M2: + return sizeof(__nv_fp8_e5m2); default: // handle or get compiler warning fprintf(stderr, "Unknown datatype\n"); exit(EXIT_FAILURE); @@ -106,38 +137,125 @@ size_t sizeof_dtype(DType type) { DType dtype_of(float* f) { return DType::FP32; } DType dtype_of(nv_bfloat16 * f) { return DType::BF16; } DType dtype_of(half * f) { return DType::FP16; } - - +DType dtype_of(__nv_fp8_e4m3 * f) { return DType::FP8E4M3; } +DType dtype_of(__nv_fp8_e5m2 * f) { return DType::FP8E5M2; } // ---------------------------------------------------------------------------- -// Copy, cast functions - -// device functions and the kernel to cast data between types -template -__device__ Td cast_value(Ts val); +// Random Number Generation used in Stochastic Rounding (defined here as used by TensorGPU) -template<> -__device__ float cast_value(float val) { - return val; +// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5) +// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU +// todo - possibly overkill and we don't need such high quality random numbers? 
(tbd)
+// http://eiserloh.net/noise/SquirrelNoise5.hpp
+__device__ __host__ unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed) {
+    constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
+    constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
+    constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011
+    constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011
+    constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101
+    unsigned int mangledBits = positionX;
+    mangledBits *= SQ5_BIT_NOISE1;
+    mangledBits += seed;
+    mangledBits ^= (mangledBits >> 9);
+    mangledBits += SQ5_BIT_NOISE2;
+    mangledBits ^= (mangledBits >> 11);
+    mangledBits *= SQ5_BIT_NOISE3;
+    mangledBits ^= (mangledBits >> 13);
+    mangledBits += SQ5_BIT_NOISE4;
+    mangledBits ^= (mangledBits >> 15);
+    mangledBits *= SQ5_BIT_NOISE5;
+    mangledBits ^= (mangledBits >> 17);
+    return mangledBits;
 }
-template<>
-__device__ float cast_value<float, half>(half val) {
-    return __half2float(val);
+// rely on default values of 0 being optimised away for 1D/2D/3D (shorter than original code)
+__device__ __host__ unsigned int get_random_noise(unsigned int seed, unsigned int x,
+                                                  unsigned int y=0, unsigned int z=0, unsigned int t=0) {
+    constexpr unsigned int PRIME1 = 198491317u; // Large prime number with non-boring bits
+    constexpr unsigned int PRIME2 = 6542989u;   // Large prime number with distinct and non-boring bits
+    constexpr unsigned int PRIME3 = 357239u;    // Large prime number with distinct and non-boring bits
+    return SquirrelNoise5(x + (PRIME1 * y) + (PRIME2 * z) + (PRIME3 * t), seed);
 }
-template<>
-__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
-    return __bfloat162float(val);
+// stochastic rounding (typically using Squirrel Noise above to go from a seed to a random number)
+// new algorithm that calculates distance from rounded up/down values to correctly handle denorms
+// (didn't matter with BF16 because denorms are so tiny they're irrelevant, unlike in FP8/FP16)
+template<typename Ti>
+__device__ void stochastic_rounding(float in, Ti &out, unsigned int random, float prob_offset=0.0f) {
+    if constexpr (std::is_same<Ti, float>::value) {
+        out = in;
+        return;
+    }
+
+    // prob_offset allows rounding towards gradient more of the time (one paper recommends that)
+    // e.g. +0.3f ==> 65% chance up, 35% chance down
+    float threshold_percentage = ((float)random / (float)0xFFFFFFFF) - prob_offset;
+
+    Ti rounded_down = (Ti)0.0f, rounded_up = (Ti)0.0f;
+    if constexpr (std::is_same<Ti, half>::value) {
+        rounded_down = __float2half_rd(in);
+        rounded_up = __float2half_ru(in);
+    } else if constexpr (std::is_same<Ti, __nv_bfloat16>::value) {
+        rounded_down = __float2bfloat16_rd(in);
+        rounded_up = __float2bfloat16_ru(in);
+    } else if constexpr (std::is_same<Ti, __nv_fp8_e4m3>::value) {
+        // CUDA doesn't have round down/up instructions for FP8 (in SW or HW) so we do it ourselves
+        // ARM-Intel-NVIDIA style FP8 E4M3 (different for AMD-Graphcore-Qualcomm format!)
+        // tried this approach to avoid bug with fake_fp8 (didn't help), keeping it for now...
+        // todo: check whether it properly matches the bit shifting method (do exhaustive testing!)
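        // (editorial note, worked example with illustrative numbers) For the BF16 path above: if in = 1.3f,
        // the neighbouring BF16 values are rounded_down = 1.296875 and rounded_up = 1.3046875 (spacing 2^-7),
        // so lerp = (1.3 - 1.296875) / 0.0078125 = 0.4 and we round up whenever threshold_percentage < 0.4,
        // i.e. with probability ~40% (shifted by prob_offset), which for prob_offset = 0 makes the rounding
        // unbiased in expectation. For E4M3 below, 0.0156 ~ 2^-6 is roughly the smallest normal and
        // 0.000975 ~ half a denorm step (2^-10); the 15.5/16 and 8.5/8 factors appear intended to nudge the
        // value by about half an FP8 step either way so the round-to-nearest casts land on the two adjacent
        // representable values.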
+ float low = in; + float high = in; + + if (fabsf(in) < 0.0156f) { + low -= 0.000975f; + high += 0.000975f; + } else { + if (in > 0.0f) { + low *= (15.5f / 16.0f); + high *= (8.5f / 8.0f); + } else { + low *= (8.5f / 8.0f); + high *= (15.5f / 16.0f); + } + } + rounded_up = (__nv_fp8_e4m3)high; + rounded_down = (__nv_fp8_e4m3)low; + } else { + assert(false); + } + + float diff = (float)rounded_up - (float)rounded_down; + float lerp = (in - (float)rounded_down) / diff; // division by 0 is OK as it means (up == down) anyway + out = (lerp > threshold_percentage) ? rounded_up : rounded_down; } -template -__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - // need to try grid stride looping for more perf later - if (idx < n) { - dst[idx + stride_dst * blockIdx.y] = cast_value(src[idx + stride_src * blockIdx.y]); +// ---------------------------------------------------------------------------- +__device__ float fake_low_precision(bool faking, float input, float scale, float descale, bool mode_e5, bool stochastic=false) { +#ifdef FAKE_LOW_PRECISION + unsigned int random_number; + if (faking && scale != 1.0f) { + assert(scale == 1.0f/descale || descale == 1.0f/scale || scale == 1.0f); + if (stochastic) { + unsigned int clock, laneid; + asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); + asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); + random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); + } + + if (mode_e5) { + __nv_fp8_e5m2 value_fp8 = __nv_fp8_e5m2(input * scale); + return ((float)value_fp8) * descale; + } else { + __nv_fp8_e4m3 value_fp8 = __nv_fp8_e4m3(input * scale); + if (stochastic) { + // BUGGED(?) - spent 6+ hours debugging and I genuinely suspect a compiler bug *sigh* + stochastic_rounding(input * scale, value_fp8, random_number); + } + return ((float)value_fp8) * descale; + } } +#endif + return input; } // ---------------------------------------------------------------------------- @@ -157,7 +275,7 @@ __device__ inline float warpReduceMax(float val) { } return val; } -// requires all 32 threads in the warp to be active, but should work for any block size +// requires all 32 threads in the warp to be active, but should work for any 1D(!) block size // uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes // the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end // but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1 @@ -205,59 +323,4 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud cudaCheck(cudaGetLastError()); } -// ---------------------------------------------------------------------------- -// Random Number Generation used in Stochastic Rounding - -// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5) -// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU -// todo - possibly overkill and we don't need such high quality random numbers? 
(tbd) -// http://eiserloh.net/noise/SquirrelNoise5.hpp -__device__ __host__ constexpr unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed) -{ - constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111 - constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111 - constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011 - constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011 - constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101 - unsigned int mangledBits = positionX; - mangledBits *= SQ5_BIT_NOISE1; - mangledBits += seed; - mangledBits ^= (mangledBits >> 9); - mangledBits += SQ5_BIT_NOISE2; - mangledBits ^= (mangledBits >> 11); - mangledBits *= SQ5_BIT_NOISE3; - mangledBits ^= (mangledBits >> 13); - mangledBits += SQ5_BIT_NOISE4; - mangledBits ^= (mangledBits >> 15); - mangledBits *= SQ5_BIT_NOISE5; - mangledBits ^= (mangledBits >> 17); - return mangledBits; -} -__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed) -{ - constexpr unsigned int PRIME_NUMBER = 198491317u; // Large prime number with non-boring bits - unsigned int x = static_cast(indexX); - unsigned int y = static_cast(indexY); - - return SquirrelNoise5(x + (PRIME_NUMBER * y), seed); -} - -// stochastic rounding built on top of Squirel Noise above (with seed updated per step via xorshift) -__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) { - // todo - is this stochastic rounding *too good*? can we cut any corners? - // makes sure each thread gets a different random number - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed); - unsigned int threshold = random & 0xFFFF; - unsigned int float_bits = __float_as_uint(in); - unsigned int rounded_bits = float_bits & 0x0000FFFF; - float_bits = (rounded_bits > threshold) ? (float_bits | 0xFFFF) : (float_bits & ~0xFFFF); - *out = __float2bfloat16_rn(__uint_as_float(float_bits)); -} -__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) { - *out = (float)in; // todo - implement this... 
-} -__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) { - *out = in; // dummy function for when floatX is float (FP32 mode) -} - #endif \ No newline at end of file diff --git a/llmc/cudnn_att.cpp b/llmc/cudnn_att.cpp index 0330abe20..3d2f8af4d 100644 --- a/llmc/cudnn_att.cpp +++ b/llmc/cudnn_att.cpp @@ -222,7 +222,7 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) float* stats, // output for backward pass: (B, NH, T) floatX* inp, // input: (B, T, 3, NH, HS) QKV - int B, int T, int NH, int C, cudaStream_t stream) { + int B, int T, int NH, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); int HS = C / NH; // number of features per head bool is_inference_only = (stats == nullptr); @@ -255,7 +255,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) void attention_backward_cudnn(floatX* dqkvr, // output floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs - int B, int T, int NH, int C, cudaStream_t stream) { + int B, int T, int NH, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); int HS = C / NH; // number of features per head diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index 3aa63e175..980165ca4 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -16,8 +16,8 @@ In the backward pass, the gradients flow to both, handled by different kernels // ---------------------------------------------------------------------------- // CUDA kernels -__global__ void encoder_forward_kernel3(floatX* out, - const int* inp, const floatX* wte, const floatX* wpe, +__global__ void encoder_forward_kernel3(tensorX out, + const int* inp, const tensorX wte, const tensorX wpe, int B, int T, int C) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; int N = B * T * C; @@ -27,25 +27,22 @@ __global__ void encoder_forward_kernel3(floatX* out, int b = bt / T; int t = bt % T; int c = idx % C; - int ix = inp[b * T + t]; - floatX* out_btc = out + b * T * C + t * C + c; - const floatX* wte_ix = wte + ix * C + c; - const floatX* wpe_tc = wpe + t * C + c; + auto out128 = new_tensor128(out); + auto wte128 = load_tensor128(wte, ix * C + c); + auto wpe128 = load_tensor128(wpe, t * C + c); - x128 packed_out; - x128 wte128 = load128cs(wte_ix); - x128 wpe128 = load128cs(wpe_tc); for (int k = 0; k < x128::size; k++) { - packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]); + out128.set(k, wte128.get(k) + wpe128.get(k)); } - store128(out_btc, packed_out); + out128.store(b * T * C + t * C + c); + out128.update_absmax(1); } template -__global__ void wte_backward_kernel(floatX* dwte, - const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp, +__global__ void wte_backward_kernel(tensorX dwte, + const int4* bucket_info, const int* workload_indices, const tensorX dout, const int* inp, unsigned int seed, int B, int T, int C) { // In order to be deterministic, we preprocess the inputs on the cpu into "buckets" // Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token @@ -75,11 +72,9 @@ __global__ void wte_backward_kernel(floatX* dwte, for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) { int bt = workload_indices[bucket_start_idx + item]; - - const floatX* dout_btc = dout + bt * C + c; - x128 packed_inp1 = load128cs(dout_btc); - for (int k = 0; k < packed_inp1.size; k++) { - accum[k] += (float)packed_inp1[k]; + auto dout128 = load_tensor128(dout, 
bt * C + c, true); + for (int k = 0; k < dout128.elements; k++) { + accum[k] += dout128.get(k); } } @@ -92,8 +87,7 @@ __global__ void wte_backward_kernel(floatX* dwte, } // Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance - floatX* dwte_ix = dwte + bucket_ix * C + c; - x128 packed_in_out = load128(dwte_ix); + auto dwte128 = load_tensor128(dwte, bucket_ix * C + c, false, true); // note: threads which have returned are considered synchronised by CUDA so no risk of deadlock __syncthreads(); @@ -105,19 +99,19 @@ __global__ void wte_backward_kernel(floatX* dwte, } } - // Add the result to dwte and write back to global memory (read-modify-write) + // add the result to dwte and write back to global memory (read-modify-write) + // we use stochastic rounding to go from FP32 to BF16/whatever (the seed is deterministic) + // reusing same random value but shifting based on the index in set_stochastic ("good enough") + unsigned int random = get_random_noise(seed, threadIdx.x, bucket); for (unsigned int k = 0; k < x128::size; k++) { - // We use stochastic rounding to go from FP32 to BF16 - // The seed is deterministic and unique for each parameter to guarantee we have determinism AND - // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB - // and that somehow messing the quality of random numbers - stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + bucket * WARP_SIZE + threadIdx.x + k); + dwte128.set_stochastic(k, accum[k] + dwte128.get(k), random); } - store128(dwte_ix, packed_in_out); + dwte128.store(bucket_ix * C + c); + dwte128.update_absmax(1); } -__global__ void wpe_backward_kernel(floatX* dwpe, - const floatX* dout, const int* inp, +__global__ void wpe_backward_kernel(tensorX dwpe, + const tensorX dout, const int* inp, int B, int T, int C, unsigned int seed) { // Each thread handles x128::size "channel positions", e.g. 
256 per warp for BF16 // For gpt2-124M BF16, C=768 and T=1024, so 3 warps per channel and 3072 warps in total @@ -133,30 +127,31 @@ __global__ void wpe_backward_kernel(floatX* dwpe, float accum[x128::size] = {0.0f}; for (int b = 0; b < B; b++) { - x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again + auto dout128 = load_tensor128(dout, b * T * C + t * C + c, true); for (int k = 0; k < x128::size; k++) { - accum[k] += (float)packed_dout[k]; + accum[k] += dout128.get(k); } } - floatX* dwpe_tc = dwpe + (t * C) + c; - x128 packed_dwpe = load128(dwpe_tc); + auto dwpe128 = load_tensor128(dwpe, t * C + c); + unsigned int random = get_random_noise(seed, t, c); for (unsigned int k = 0; k < x128::size; k++) { // We use stochastic rounding to go from FP32 to BF16 // The seed is deterministic and unique for each parameter to guarantee we have determinism AND // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB // and that somehow messing the quality of random numbers - stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + idx + k); + dwpe128.set_stochastic(k, accum[k] + dwpe128.get(k), random); } - store128(dwpe_tc, packed_dwpe); + dwpe128.store(t * C + c); + dwpe128.update_absmax(1); } // ---------------------------------------------------------------------------- // kernel launchers -void encoder_forward(floatX* out, - const int* inp, const floatX* wte, const floatX* wpe, - int B, int T, int C, cudaStream_t stream) { +void encoder_forward(tensorX out, + const int* inp, const tensorX wte, const tensorX wpe, + int B, int T, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int N = B * T * C; @@ -166,10 +161,10 @@ void encoder_forward(floatX* out, } // Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details) -void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch +void encoder_backward(tensorX dwte, tensorX dwpe, tensorX scratch, // gpu outputs & scratch int* workload_indices, int4* bucket_info, // cpu scratch buffers - const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs - int B, int T, int C, unsigned int seed, cudaStream_t stream) { + const tensorX dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs + int B, int T, int C, unsigned int seed, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte) @@ -222,7 +217,7 @@ void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu output // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice) // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely - int4* d_bucket_info = (int4*)scratch; + int4* d_bucket_info = (int4*)scratch.data_ptr; int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4)); cudaCheck(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream)); cudaCheck(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream)); diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index 4837d4cb0..8a0ac8962 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -16,23 +16,22 @@ struct SoftmaxParams { float Offset; }; -__device__ SoftmaxParams 
prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, tensorX inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) - - const floatX* x = inp + idx * P; + int elements = inp.num_per_128(); float thread_maxval = -INFINITY; float thread_sumval = 0.0f; - int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x; + int i = (V+elements-1)/elements + threadIdx.x - blockDim.x; // special-case loop to handle the unaligned elements at the end of the array // this lets us skip the bounds check in the main loop below, which improves performance - while ((i+1)*x128::size > V) { - for(int k = 0; k < x128::size; ++k) { - if (i*x128::size+k >= V) { + while ((i+1)*elements > V) { + for(int k = 0; k < elements; ++k) { + if (i*elements+k >= V) { break; // bounds checking against real V (rather than padded P) } - float v = (float)x[i*x128::size+k]; + float v = inp.get_scalar(idx * P + i * elements + k); float old_maxval = thread_maxval; thread_maxval = fmaxf(thread_maxval, v); thread_sumval *= expf((old_maxval - thread_maxval)); @@ -43,9 +42,9 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // main loop for the bulk of the iterations (no bounds checking required!) for (; i >= 0; i -= blockDim.x) { - x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop - for(int k = 0; k < x128::size; ++k) { - float v = (float)packed_x[k]; + auto inp128 = load_tensor128(inp, idx * P + i * elements); + for(int k = 0; k < elements; ++k) { + float v = inp128.get(k); float old_maxval = thread_maxval; thread_maxval = fmaxf(thread_maxval, v); thread_sumval *= expf((old_maxval - thread_maxval)); @@ -67,13 +66,14 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) - fused_classifier_kernel5(floatX* logits, float* losses, floatX* probs, + fused_classifier_kernel5(tensorX dlogits, tensorX logits, float* losses, floatX* probs, const float dloss, const int* targets, - int B, int T, int V, int P, std::bool_constant) { + int V, int P, std::bool_constant) { // note: idx is small enough that it easily fits into 32 bit; - // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) + // by making it size_t here, we ensure that any offsets calculated with it (e.g., idx * P) // are done is 64 bit - int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data + int elements = logits.num_per_128(); + size_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -81,7 +81,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // calculate the probability needed for the loss and update (single-threaded) if(threadIdx.x == 0) { - float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale; + float prob = expf(logits.get_scalar(idx * P + ix) - sp.Offset) * sp.Scale; losses[idx] -= logf(prob); } @@ -93,57 +93,60 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // calculate the gradients directly, saves bandwidth from probs during training // but also supports writing probs for inference-only and debugging - const 
floatX* logits_vec = logits + idx * P; - for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) { + tensor128 dlogits128 = new_tensor128(dlogits, true); + for (int i = threadIdx.x; i < V/elements; i += blockDim.x) { // this is the 2nd read of logits after the one in prepare_softmax2 // it will be overwritten by the logits gradients which is when we reduce cache persistence - x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs - x128 packed_probs; - for(int k = 0; k < x128::size; ++k) { - int element = i*x128::size + k; - float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale; + auto logits128 = load_tensor128(logits, idx * P + i * elements, false, true); + x128 packed_probs; // todo - unused but might be read on CPU in the future so not scaling (???) + for(int k = 0; k < elements; ++k) { + int element = i*elements + k; + float prob = expf(logits128.get(k) - sp.Offset) * sp.Scale; packed_probs[k] = (floatX)prob; float indicator = (element == ix) ? 1.0f : 0.0f; - packed_logits_vec[k] = (floatX)((prob - indicator) * dloss); + dlogits128.set(k, (prob - indicator) * dloss); } - if (WriteDLogits){ + if constexpr (WriteDLogits) { // reduce cache persistence for the overwritten logits // to maximise probability that logits remain in cache between prepare_softmax and here - store128cs(logits + idx * P + i * x128::size, packed_logits_vec); + dlogits128.store(idx * P + i * elements, true); } if (WriteProbs) { - store128(probs + idx * P + i * x128::size, packed_probs); + store128(probs + idx * P + i * elements, packed_probs); } } - // handle remaining elements after the last multiple of x128::size + // handle remaining elements after the last multiple of the number of elements // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements - int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size + int unaligned_start = V & ~(elements - 1); // round down to multiple of x128::size for (int i = threadIdx.x + unaligned_start; i < V; i++) { - float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale; + float prob = expf(logits.get_scalar(idx * P + i) - sp.Offset) * sp.Scale; float indicator = (i == ix) ? 1.0f : 0.0f; float dlogit = (prob - indicator) * dloss; if (WriteDLogits){ - __stcs(logits + idx * P + i, (floatX)dlogit); + floatX dlogitX = dlogits.set_scalar(idx * P + i, dlogit); // write to memory + dlogits128.add_value_stats(dlogit, dlogitX); // add to absmax stats etc. 
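As an aside for readers following the fused classifier changes: the per-row math the kernel implements (softmax, cross-entropy loss, and the logit gradient in one pass) can be written as a plain CPU reference like the sketch below. This is an illustrative sketch only, not part of the patch; it uses plain floats, ignores the padded vocabulary size P, and computes the max/sum in two passes where the kernel does it online.

#include <cmath>
#include <vector>

// Reference: loss = -log softmax(logits)[target], dlogits[j] = (softmax(logits)[j] - 1{j == target}) * dloss
void fused_classifier_reference(const std::vector<float>& logits, int target, float dloss,
                                float& loss, std::vector<float>& dlogits) {
    float maxval = -INFINITY;
    for (float v : logits) maxval = std::fmax(maxval, v);
    float sumval = 0.0f;
    for (float v : logits) sumval += std::exp(v - maxval);
    float scale = 1.0f / sumval;                      // plays the role of sp.Scale; maxval is sp.Offset
    loss = -std::log(std::exp(logits[target] - maxval) * scale);
    dlogits.resize(logits.size());
    for (size_t j = 0; j < logits.size(); j++) {
        float prob = std::exp(logits[j] - maxval) * scale;
        float indicator = (j == (size_t)target) ? 1.0f : 0.0f;
        dlogits[j] = (prob - indicator) * dloss;
    }
}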
} if (WriteProbs) { probs[idx * P + i] = (floatX)prob; } } + if constexpr (WriteDLogits) { + dlogits128.update_absmax(1); + } } // ---------------------------------------------------------------------------- // kernel launchers // replaces logits with logit gradients -template -void fused_classifier(Type* logits, float* losses, +template +void fused_classifier(tensorX dlogits, tensorX logits, tensor32 losses, const float dloss, const int* targets, - int B, int T, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream) { + int BT, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 1024; - const int N = B * T; - const int grid_size = N; - fused_classifier_kernel5<<>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits); + const int grid_size = BT; + fused_classifier_kernel5<<>>(dlogits, logits, losses, (floatX*)NULL, dloss, targets, V, P, write_dlogits); cudaCheck(cudaGetLastError()); } diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index cd5c297b6..1c0af0f11 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -10,57 +10,85 @@ // CUDA kernels #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) -__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; +template +__global__ void gelu_forward_kernel2(TensorGPU out, TensorGPU inp) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128(); - x128 packed_out; - x128 packed_inp = load128cs(inp + idx); // load and do not keep in cache - for(int k = 0; k < packed_inp.size; ++k) { - float xi = (float)packed_inp[k]; + auto out128 = new_tensor128(out); + auto inp128 = load_tensor128(inp, idx, true); + for(int k = 0; k < inp.num_per_128(); ++k) { + float xi = inp128.get(k); float cube = 0.044715f * xi * xi * xi; - packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)))); + + float tanh_in_out = GELU_SCALING_FACTOR * (xi + cube); + #if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750 + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out)); + #else + tanh_in_out = tanhf(tanh_in_out); + #endif + + // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)" + float half_xi = 0.5f * xi; + out128.set(k, half_xi * tanh_in_out + half_xi); } - // store instead of storecs (without cache streaming) in case it is useful for the - // data to be in the cache for the next operation after this GeLU - store128(out + idx, packed_out); + out128.store(idx, false); + out128.update_absmax(1); } -__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; +//template +template +__global__ void gelu_backward_kernel(TensorGPU dinp, TensorGPU dout, TensorGPU inp) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); - x128 packed_dinp; - x128 packed_inp = load128cs(inp + idx); - x128 packed_dout = load128(d_in_out + idx); - for (int k = 0; k < packed_inp.size; ++k) { - float x = (float)packed_inp[k]; + auto dinp128 = new_tensor128(dinp); + auto inp128 = load_tensor128(inp, idx, true); + auto dout128 = load_tensor128(dout, idx); + for (int k = 0; k < dout.num_per_128(); ++k) { + float x = inp128.get(k); float cube = 0.044715f * x * x * x; - float tanh_arg = GELU_SCALING_FACTOR * (x + cube); - float tanh_out = tanhf(tanh_arg); - float coshf_out = coshf(tanh_arg); - float sech_out 
= 1.0f / (coshf_out * coshf_out); - float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x); - packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]); + + float tanh_in_out = GELU_SCALING_FACTOR * (x + cube); + #if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750 + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out)); + #else + tanh_in_out = tanhf(tanh_in_out); + #endif + + float sech_out = 1.0f - (tanh_in_out * tanh_in_out); + float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x)); + float result = local_grad * dout128.get(k); + dinp128.set(k, result); } - store128(d_in_out + idx, packed_dinp); + dinp128.store(idx, false); + dinp128.update_absmax(1); } // ---------------------------------------------------------------------------- // kernel launchers - -void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) { +template +void gelu_forward(TensorGPU out, TensorGPU inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 512; - assert(N % (block_size * x128::size) == 0); - const int grid_size = CEIL_DIV(N, block_size * x128::size); + const int block_size = 256; + assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); + + const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); gelu_forward_kernel2<<>>(out, inp); cudaCheck(cudaGetLastError()); } -void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream) { +template +void gelu_backward(TensorGPU dinp, TensorGPU dout, TensorGPU inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 128; - assert(N % (block_size * x128::size) == 0); - const int grid_size = CEIL_DIV(N, block_size * x128::size); - gelu_backward_inplace_kernel<<>>(d_in_out, inp); + const int block_size = 256; + const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); + gelu_backward_kernel<<>>(dinp, dout, inp); cudaCheck(cudaGetLastError()); } + +void gelu_forward_fp8(tensor8 out, tensor8 inp, cudaStream_t stream=main_stream) { + gelu_forward(out, inp, stream); +} + +void gelu_backward_fp8(tensor8e5 dinp, tensor8e5 dout, tensor8 inp, cudaStream_t stream=main_stream) { + gelu_backward(dinp, dout, inp, stream); +} \ No newline at end of file diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index e0e23b08a..53a0a7490 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -1,3 +1,5 @@ +// TODO - BUGGED - just committing my WIP, not sure why grad norm is zero, probably something silly! 
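Given the TODO above about the gradient norm coming out as zero, a host-side reference is a quick way to narrow down whether the tensor iteration or the reduction is at fault. The sketch below is illustrative only and assumes the per-tensor gradients have been copied back to host float buffers; the names are hypothetical, not the actual llm.c structures.

#include <cmath>
#include <vector>

// Reference value to compare against the kernel output: the kernels accumulate the squared
// norm (sum of grad^2 over every gradient tensor); the sqrt is taken afterwards on the host.
float global_grad_norm_reference(const std::vector<std::vector<float>>& grad_tensors) {
    double norm_squared = 0.0;   // accumulate in double to keep rounding error out of the comparison
    for (const auto& tensor : grad_tensors) {
        for (float g : tensor) {
            norm_squared += (double)g * (double)g;
        }
    }
    return (float)std::sqrt(norm_squared);
}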
+ /* Global norm, used in gradient clipping */ @@ -11,79 +13,81 @@ Global norm, used in gradient clipping // ---------------------------------------------------------------------------- // CUDA kernels -template -__device__ float global_norm_squared_for_range(const T* data, size_t count) { - size_t index = blockIdx.x * blockDim.x + threadIdx.x; - size_t grid_width = blockDim.x * gridDim.x; +__device__ float global_norm_tensors_loop(size_t idx, unsigned int stride, int num_params_tensors, unsigned int shard_idx) { float accumulator = 0.f; - for(size_t i = index; i < count; i += grid_width) { - accumulator += (float)data[i] * (float)data[i]; - } - // block-level reduce - return blockReduce(accumulator); -} + int opt_m_spec_id = 2 * num_params_tensors - 1; // -1 as it gets incremented at the start of the loop below + int last_opt_m_id = 3 * num_params_tensors - 1; // opt_m is fully sharded with ZeRO 1 so we use it as a reference + + while (true) { + size_t current_end; + // optimized critical path loop to iterate over tensors: only 8 SASS instructions! + // 3 SETP, 2 BRA, 1 IADD3, 1 IMAD, and of course 1 LDG.E.LTC256B.64 + do { + opt_m_spec_id++; + if (opt_m_spec_id > last_opt_m_id) return accumulator; // return and write the result to memory + + // on A100+ we can prefetch 256B (32 values) into the L2, on older GPUs just use a regular load + // (this improved DRAM utilization from ~81.5% to ~83.5% on my H100 PCIe) + #if __CUDA_ARCH__ < 800 + current_end = tensor_end_element_ptr[opt_m_spec_id]; + #else + asm("ld.global.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); + #endif + } while (idx >= current_end); -template -__global__ void global_norm_squared_kernel(float* out, const T* data, size_t count, ptrdiff_t stride) { - float block_sum = global_norm_squared_for_range(data + blockIdx.y * stride, count); - // each block accumulates its partial sum to out[out_index] - // we want to avoid using atomic add here so we combine this kernel with another kernel call - // that sums up the partial block sums - if(threadIdx.x == 0) { - size_t out_index = blockIdx.y * gridDim.x + blockIdx.x; - out[out_index] = out[out_index] + block_sum; + // offset is 32-bit (we check parameters tensors have less than 4B elements in add_tensor_spec) + size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element; + unsigned int offset = (idx - current_start) + (shard_idx * tensor_specs_ptr[opt_m_spec_id].num_elements); + + int grad_spec_id = opt_m_spec_id - num_params_tensors; + TensorGPU grad_tensor = tensor_specs_ptr[grad_spec_id]; + + __syncthreads(); // todo - check that this does improve performance (better memory locality) + while (idx < current_end) { // todo - profile number of iterations and adding an inner loop + auto grad128 = load_tensor128(grad_tensor, offset, false, true); + for (int k = 0; k < grad_tensor.num_per_128(); k++) { + float grad = grad128.get(k); + accumulator += grad * grad; + } + idx += stride; + offset += stride; + } } } -__global__ void global_norm_aggregate_kernel(float* out, size_t grid_size) { - size_t index = threadIdx.x; - // grab block sums from the previous kernel, use 0. as the neutral sum element - float block_sum = (index < grid_size) ? 
out[index] : 0.f; - float sum = blockReduce(block_sum); - if(threadIdx.x == 0) { - out[0] = sum; // out[0] ends up with the final norm squared +// currently assumes all gradients are the same type (simplified adamw_update_everything) +// ZeRO 1 should use shard_idx, while DPP and ZeRO 2/3 should simply set it to 0 +template +__global__ void __launch_bounds__(256, MAX_THREADS/256) global_norm_tensors_kernel(float* out, int num_params_tensors, unsigned int shard_idx) { + constexpr size_t block_size = 256; + constexpr size_t iteration_size = Packed128::size; + unsigned int stride = gridDim.x * blockDim.x * iteration_size; + size_t idx = (blockIdx.x * block_size + threadIdx.x) * iteration_size; + + float accumulator = global_norm_tensors_loop(idx, stride, num_params_tensors, shard_idx); + + float output = blockReduce(accumulator); + if (threadIdx.x == 0) { + out[blockIdx.x] = output; } } // ---------------------------------------------------------------------------- // kernel launcher -// Helper function determines the maximum number of block sums -int get_max_num_block_sums(int* num_slices_all, int numel) { - // NOTE: this needs to be kept in sync with `global_norm_squared` below. - const int block_size = 512; +template +void global_norm_tensors(float* out, int gpu_process_rank, cudaStream_t stream=main_stream) { + const int block_size = 256; const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; - assert(grid_size > 0); - int max_num_block_sums = 0; - for (int i = 0; i < numel; i++) { - int num_slices = num_slices_all[i]; - const int gx = CEIL_DIV(grid_size, num_slices); - const int gy = num_slices; - max_num_block_sums = max(max_num_block_sums, gx * gy); - } - - return max_num_block_sums; -} - -template -void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, int max_num_block_sums, bool reset, cudaStream_t stream) { - const int block_size = 512; - // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. - // having one block less than possible is a tiny performance hit, having - // one block too many is catastrophic, since it only can start once all the other - // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512 - // on all gpus, so the division really is going to be exact. 
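The removed comments above describe the launch-sizing rationale that the new launcher keeps: size the grid to exactly one wave of resident blocks and let a stride loop inside the kernel cover the remaining elements. A minimal, self-contained sketch of that pattern (illustrative only, independent of the tensor machinery in this patch):

#include <cuda_runtime.h>

__global__ void sum_squares_kernel(float* partial, const float* data, size_t count) {
    __shared__ float smem[256];                      // matches the block size used below
    float acc = 0.0f;
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += (size_t)gridDim.x * blockDim.x) {
        acc += data[i] * data[i];
    }
    smem[threadIdx.x] = acc;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {   // tree reduction within the block
        if (threadIdx.x < (unsigned)s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
    }
    if (threadIdx.x == 0) partial[blockIdx.x] = smem[0];   // one partial sum per block, summed later
}

void launch_sum_squares(float* partial, const float* data, size_t count, const cudaDeviceProp& prop) {
    const int block_size = 256;
    // exactly fill the GPU; one extra block would have to wait for an entire wave to drain first
    const int grid_size = prop.maxThreadsPerMultiProcessor * prop.multiProcessorCount / block_size;
    sum_squares_kernel<<<grid_size, block_size>>>(partial, data, count);   // partial holds grid_size floats
}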
- const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; - assert(grid_size > 0); // gives a better error than letting the call below fail - const int gx = CEIL_DIV(grid_size, num_slices); - const int gy = num_slices; + int num_params_tensors = tensors_start[PARAMETER+1]; + int num_shards_opt = tensor_specs[tensors_start[PARAMETER_OPT_M]].num_shards; + int num_shards_grad = tensor_specs[tensors_start[PARAMETER_GRAD]].num_shards; + int num_shards = num_shards_opt / num_shards_grad; // should work for both DPP and ZeRO 1/2/3 + int shard_idx = gpu_process_rank % num_shards; - assert(gx * gy < 1024); // we want to later accumulate the block sums in a single block - - if (reset) { - cudaCheck(cudaMemsetAsync(out, 0, max_num_block_sums * sizeof(float), stream)); - } - global_norm_squared_kernel<<>>(out, values, count, stride); + global_norm_tensors_kernel<<>>(out, num_params_tensors, shard_idx); cudaCheck(cudaGetLastError()); -} + global_sum_deterministic(out, out, grid_size, stream); + cudaCheck(cudaGetLastError()); +} \ No newline at end of file diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 9777d0658..db6892074 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -13,93 +13,30 @@ E.g., the layernorms are connected to the residuals so we += in layernorm backwa // llmc internal imports #include "cuda_common.h" #include "cuda_utils.cuh" +#include "tensor.cuh" // ---------------------------------------------------------------------------- // CUDA kernels -__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, - const floatX* __restrict__ inp, const floatX* __restrict__ weight, - const floatX* __restrict__ bias, int N, int C) { - int lane_id = threadIdx.x % WARP_SIZE; - int warp_id = threadIdx.x / WARP_SIZE; - int num_warps = blockDim.x / WARP_SIZE; +template +__global__ void layernorm_forward_kernel6(TensorGPU out, tensor32 mean, tensor32 rstd, + tensorX inp, tensorX weight, + tensorX bias, int N, int C) { + // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here + int idx = blockIdx.x * blockDim.y + threadIdx.y; // non-standard: threadIdx.x is used for c + if(idx >= N) { return; } - int idx = blockIdx.x * num_warps + warp_id; - if(idx >= N) { return; } // guard - - // the row of input that this group of threads is responsible for - const floatX* x = inp + idx * C; - - // mean - float sum = 0.0f; - for (int i = lane_id; i < C; i += WARP_SIZE) { - sum += (float)x[i]; - } - sum = warpReduceSum(sum); - float m = sum / C; - if(lane_id == 0 && mean != nullptr) { - __stcs(mean + idx, m); - } - - // rstd - sum = 0.0f; - for (int i = lane_id; i < C; i += WARP_SIZE) { - float diff = (float)x[i] - m; - sum += diff * diff; - } - sum = warpReduceSum(sum); - float s = rsqrtf(sum / C + 1e-5f); - if(lane_id == 0 && rstd != nullptr) { - __stcs(rstd + idx, s); - } - - // final normalization and scaling by weight/bias - floatX* o = out + idx * C; - for (int c = lane_id; c < C; c += WARP_SIZE) { - // load and store using the .cs "streaming" hint to the compiler, - // indicating that this data will not be reused soon, and can be streamed through the caches - // this allows the threads to get more cache-hits for the (shared) weight and bias parameters - float n = s * ((float)__ldcs(x+c) - m); - __stcs(o+c, (floatX)(n * (float)weight[c] + (float)bias[c])); - } -} - -__global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* 
__restrict__ mean, float* __restrict__ rstd, - const floatX* __restrict__ inp, const floatX* __restrict__ weight, - const floatX* __restrict__ bias, int N, int C) { - assert(blockDim.x == WARP_SIZE); - - // load weights and biases into shared memory - // do this before we allow any threads to exit! + // load/store128 sometimes generated multiple instructions with floatX, so keep it as x128 extern __shared__ char* params[]; - // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so - // let's keep everything as x128 - x128* s_weight = reinterpret_cast(params); - x128* s_bias = reinterpret_cast(params) + (C / x128::size); - x128* s_in = reinterpret_cast(params) + ((2 + threadIdx.y) * C / x128::size); - - int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; - for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { - s_weight[i/x128::size] = load128(weight + i); - s_bias[i/x128::size] = load128(bias + i); - } - __syncthreads(); + x128* s_in = reinterpret_cast(params) + (threadIdx.y * C / x128::size); - int idx = blockIdx.x * blockDim.y + threadIdx.y; - if(idx >= N) { return; } // guard - - // adjust pointers to current token - inp += idx * C; - out += idx * C; - - const float eps = 1e-5f; float sum = 0.0f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in_data = load128cs(inp + c); + auto inp128 = load_tensor128(inp, idx * C + c, true); for(int k = 0; k < x128::size; ++k) { - sum += (float)in_data[k]; + sum += inp128.get(k); } - s_in[c / x128::size] = in_data; + s_in[c / x128::size] = inp128.get128(); } sum = warpReduceSum(sum); @@ -114,74 +51,58 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res } v = warpReduceSum(v) / C; + const float eps = 1e-5f; // todo - is this optimal / theoretically justified? float s = rsqrtf(v + eps); + auto out128 = new_tensor128(out); for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { const x128 in_data = s_in[c / x128::size]; - const x128 w = s_weight[c / x128::size]; - const x128 b = s_bias[c / x128::size]; - x128 out_data; + auto w128 = load_tensor128(weight, c); + auto b128 = load_tensor128(bias, c); for(int k = 0; k < x128::size; ++k) { float n = s * ((float)in_data[k] - m); // normalized output - float o = n * (float)w[k] + (float)b[k]; // scale and shift it - out_data[k] = (floatX)o; + float o = n * w128.get(k) + b128.get(k); // scale and shift it + out128.set(k, o); } - - store128cs(out + c, out_data); + out128.template store_same_length(idx * C + c); } // cache the mean and rstd for the backward pass later - if(threadIdx.x == 0 && mean != nullptr) { + if(threadIdx.x == 0) { // todo - add a way to pass equivalent of null for mean/rstd to avoid store __stcs(mean + idx, m); - } - // store the rstd, no need to cache it - if(threadIdx.x == 0 && rstd != nullptr) { __stcs(rstd + idx, s); } + // update absmax + out128.update_absmax(2); } -__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd, - const floatX* inp1, const floatX* inp2, - const floatX* weight, const floatX* bias, +template +__global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensor32 mean, tensor32 rstd, + const tensorX inp1, const TensorGPU inp2, + const tensorX weight, const tensorX bias, int N, int C) { - assert(blockDim.x == WARP_SIZE); - - // load weights and biases into shared memory - // do this before we allow any threads to exit! 
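For readers cross-checking the rewritten forward kernel, the per-token math is unchanged; a plain CPU reference (illustrative only, floats throughout, same eps as the kernel) looks like this:

#include <cmath>

void layernorm_forward_reference(float* out, float* mean, float* rstd,
                                 const float* inp, const float* weight, const float* bias, int C) {
    float m = 0.0f;
    for (int c = 0; c < C; c++) m += inp[c];
    m /= C;
    float v = 0.0f;
    for (int c = 0; c < C; c++) { float d = inp[c] - m; v += d * d; }
    v /= C;
    float s = 1.0f / std::sqrt(v + 1e-5f);   // rstd, with the same eps = 1e-5f
    for (int c = 0; c < C; c++) {
        out[c] = s * (inp[c] - m) * weight[c] + bias[c];
    }
    *mean = m;   // cached for the backward pass
    *rstd = s;
}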
- extern __shared__ char* params[]; - // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so - // let's keep everything as x128 - x128* s_weight = reinterpret_cast(params); - x128* s_bias = reinterpret_cast(params) + (C / x128::size); - x128* s_res = reinterpret_cast(params) + ((2 + threadIdx.y) * C / x128::size); - - int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; - for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { - s_weight[i/x128::size] = load128(weight + i); - s_bias[i/x128::size] = load128(bias + i); - } - __syncthreads(); - + // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here int idx = blockIdx.x * blockDim.y + threadIdx.y; if(idx > N) return; - // adjust pointers to current token - residual += C * idx; - normed += C * idx; - inp1 += C * idx; - inp2 += C * idx; + // load/store128 sometimes generated multiple instructions with floatX, so keep it as x128 + extern __shared__ char* params[]; + x128* s_res = reinterpret_cast(params) + (threadIdx.y * C / x128::size); + + auto residual128 = new_tensor128(residual); + auto normed128 = new_tensor128(normed); const float eps = 1e-5f; float sum = 0.0f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in1 = load128cs(inp1 + c); - const x128 in2 = load128cs(inp2 + c); - x128 out; + auto inp1_128 = load_tensor128(inp1, idx * C + c, true); + auto inp2_128 = load_tensor128(inp2, idx * C + c, true); for(int k = 0; k < x128::size; ++k) { - out[k] = (float)in1[k] + (float)in2[k]; - sum += (float)out[k]; + float out = inp1_128.get(k) + inp2_128.get(k); + residual128.set(k, out); + sum += residual128.get(k); } - store128cs(residual + c, out); - s_res[c / x128::size] = out; + residual128.store(idx * C + c, false); + s_res[c / x128::size] = residual128.get128(); } sum = warpReduceSum(sum); @@ -200,42 +121,32 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { const x128 res = s_res[c / x128::size]; - const x128 w = s_weight[c / x128::size]; - const x128 b = s_bias[c / x128::size]; - x128 out; + auto w128 = load_tensor128(weight, c); + auto b128 = load_tensor128(bias, c); for(int k = 0; k < x128::size; ++k) { float n = s * ((float)res[k] - m); // normalized output - float o = n * (float)w[k] + (float)b[k]; // scale and shift it - out[k] = o; + float o = n * w128.get(k) + b128.get(k); // scale and shift it + normed128.set(k, o); } - - store128cs(normed + c, out); + normed128.template store_same_length(idx * C + c, false); } // cache the mean and rstd for the backward pass later if(threadIdx.x == 0) { - mean[idx] = m; - rstd[idx] = s; + __stcs(mean + idx, m); + __stcs(rstd + idx, s); } -} -__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; - - x128 packed_out; - x128 packed_inp1 = load128cs(inp1 + idx); - x128 packed_inp2 = load128cs(inp2 + idx); - for (int k = 0; k < packed_inp1.size; k++) { - packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]); - } - store128(out + idx, packed_out); + // Update absmax for residual and normed tensors (typically it will skip residual as it is not FP8) + residual128.update_absmax(2); + normed128.update_absmax(2); } +template __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? 
- layernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, - const float* mean, const float* rstd, - int B, int T, int C) { - int BLOCK_SIZE = blockDim.x; + layernorm_backward_kernel10(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensor32 scratch_, + TensorGPU dout, tensorX inp, tensorX weight, tensor32 mean, tensor32 rstd, + int BT, int C) { + int BLOCK_SIZE = blockDim.x; // todo - does it make any difference if this is hardcoded here? int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; @@ -263,22 +174,19 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } __syncthreads(); - for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) { - const floatX* dout_bt = dout + bt * C; - const floatX* inp_bt = inp +bt * C; - floatX* dinp_bt = dinp + bt * C; + auto dinp_new128 = new_tensor128(dinp_new); - // first: two reduce operations + for (int bt = baseIdx; bt < BT; bt += warpsInGrid) { float dnorm_mean = 0.0f; float dnorm_norm_mean = 0.0f; for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { - x128 dout128_i = load128(dout_bt + i); - x128 inp128_i = load128(inp_bt + i); - x128 weight128_i = load128(weight + i); + auto dout128_i = load_tensor128(dout, bt * C + i); + auto inp128_i = load_tensor128(inp, bt * C + i); + auto weight128_i = load_tensor128(weight, i); for (int k = 0; k < x128::size; k++) { - float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; + float dnorm_i = weight128_i.get(k) * dout128_i.get(k); dnorm_mean += dnorm_i; - dnorm_norm_mean += dnorm_i * (float)inp128_i[k]; + dnorm_norm_mean += dnorm_i * inp128_i.get(k); } } @@ -290,16 +198,18 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with for (int c = 0; c < iterations_C; c++) { int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); - x128 dout128 = x128::zeros(); - x128 inp128 = x128::zeros(); - x128 dinp128 = x128::zeros(); - x128 weight128 = x128::zeros(); + tensor128 dout128; + tensor128 inp128; + tensor128 weight128; + tensor128 dinp128; if(global_index < C) { - dout128 = load128cs(dout_bt + global_index); - inp128 = load128cs(inp_bt + global_index); - dinp128 = load128(dinp_bt + global_index); - weight128 = load128(weight + global_index); + dout128 = load_tensor128(dout, bt * C + global_index, true); + inp128 = load_tensor128(inp, bt * C + global_index, true); + weight128 = load_tensor128(weight, global_index); + if constexpr (!zero_dinp_old) { + dinp128 = load_tensor128(dinp_old, bt * C + global_index); + } } for(int o = 0; o < x128::size / f128::size; ++o) { @@ -307,17 +217,17 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with f128 dweight_f; for(int i = 0; i < f128::size; ++i) { int x = o * f128::size + i; - float dout_i = (float)dout128[x]; - float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; + float dout_i = dout128.get(x); + float norm_bti = (inp128.get(x) - mean_bt) * rstd_bt; dbias_f[i] = dout_i; dweight_f[i] = norm_bti * dout_i; float dval = 0.0f; - dval += (float) weight128[x] * (float)dout128[x]; // term 1 + dval += weight128.get(x) * dout128.get(x); // term 1 dval -= dnorm_mean; // term 2 dval -= norm_bti * dnorm_norm_mean; // term 3 dval *= rstd_bt; // final scale - dinp128[x] = (floatX) ((float) dinp128[x] + dval); + dinp_new128.set(x, dinp128.get(x) + dval); } if (warpId != 0) { @@ -352,15 +262,20 
@@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } } if(global_index < C) { - // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing - store128cg(dinp_bt + global_index, dinp128); + dinp_new128.store_same_length(bt * C + global_index, false); } } } - __syncthreads(); + + // if we did actually update the absmax (returns true), we already did __syncthreads() here + if (!dinp_new128.update_absmax(1)) { + __syncthreads(); + } + // Each block writes its partial sum to global memory // The last block to finish becomes responsible for summing up all the partial sums // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + float* scratch = (float*)scratch_; unsigned int* scratchFlag = (unsigned int*)(scratch); // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned scratch += 32; @@ -403,95 +318,81 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // convert from float/FP32 to floatX/BF16 for the final write // this is separate because it cannot use as many warps as the above (f128 vs x128) // todo - if we split this code into another kernel, we could maybe do it at the same time? + auto dbias128_out = new_tensor128(dbias); + auto dweight128_out = new_tensor128(dweight); for (int c = warpId; c < iterations_C; c += warpsInBlock) { int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); if (global_index >= C) { break; } - x128 dbias128 = load128(dbias + global_index); - x128 dweight128 = load128(dweight + global_index); + auto dbias128 = load_tensor128(dbias, global_index); + auto dweight128 = load_tensor128(dweight, global_index); for(int o = 0; o < x128::size / f128::size; ++o) { f128 s_db = load128(dbias_shared + global_index + o * f128::size); f128 s_dw = load128(dweight_shared + global_index + o * f128::size); for(int i = 0; i < f128::size; ++i) { int x = o * f128::size + i; - dbias128[x] = (floatX)(s_db[i] + (float)dbias128[x]); - dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]); + dbias128_out.set(x, s_db[i] + dbias128.get(x)); + dweight128_out.set(x, s_dw[i] + dweight128.get(x)); } } - store128(dbias + global_index, dbias128); - store128(dweight + global_index, dweight128); + dbias128_out.store_same_length(global_index); + dweight128_out.store_same_length(global_index); } + dbias128_out.update_absmax(1); + dweight128_out.update_absmax(1); } } // ---------------------------------------------------------------------------- // kernel launchers -// similar to `fused_residual_forward5` -void layernorm_forward(floatX* out, float* mean, float* rstd, - floatX* inp, const floatX* weight, const floatX* bias, - int B, int T, int C, cudaStream_t stream) { - NVTX_RANGE_FN(); - const int block_size = 256; +// Helper function to set the block size based on available shared memory and launch the kernel +template +void launch_layernorm_kernel(KernelFunc kernel, int N, int C, cudaStream_t stream, Args... args) { + int block_size = 256; int block_y = block_size / WARP_SIZE; - const int N = B * T; - const int grid_size = CEIL_DIV(N, block_y); - size_t smem = (2 + block_y) * C * sizeof(floatX); - - // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute - // this may fail, in which case we fall back to the smem free implementation. 
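To make the three "terms" in the backward kernel above easier to follow, here is the per-token backward math as a plain CPU reference. This is illustrative only; the kernel additionally folds in the dinp_old accumulation, the cross-block dbias/dweight reduction, and the absmax tracking.

void layernorm_backward_reference(float* dinp, float* dweight, float* dbias,
                                  const float* dout, const float* inp, const float* weight,
                                  float mean, float rstd, int C) {
    // reduce: mean of w*dout and mean of (w*dout)*norm over the channels of this token
    float dnorm_mean = 0.0f, dnorm_norm_mean = 0.0f;
    for (int i = 0; i < C; i++) {
        float norm = (inp[i] - mean) * rstd;
        float dnorm = weight[i] * dout[i];
        dnorm_mean += dnorm;
        dnorm_norm_mean += dnorm * norm;
    }
    dnorm_mean /= C;
    dnorm_norm_mean /= C;

    for (int i = 0; i < C; i++) {
        float norm = (inp[i] - mean) * rstd;
        dbias[i] += dout[i];
        dweight[i] += norm * dout[i];
        float dval = weight[i] * dout[i];   // term 1
        dval -= dnorm_mean;                 // term 2
        dval -= norm * dnorm_norm_mean;     // term 3
        dinp[i] += dval * rstd;             // final scale
    }
}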
- cudaCheck(cudaGetLastError()); - auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cudaCheck(cudaGetLastError()); - if (status == cudaSuccess) { - layernorm_forward_kernel6<<>>(out, mean, rstd, inp, weight, bias, N, C); - } else { - // fall back to the version without shared memory - const int grid_size_fb = CEIL_DIV(N * WARP_SIZE, block_size); - layernorm_forward_kernel3<<>>(out, mean, rstd, inp, weight, bias, N, C); + size_t smem = block_y * C * sizeof(floatX); + auto status = cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + // if we don't have enough shared memory, try smaller block sizes down to 32 threads + // should fit on practically every modern GPU even for very large numbers of channels + // todo - do we want to manually set the shared memory vs L1 carveout as well? + while (status != cudaSuccess) { + if (block_y == 1) { + printf("ERROR: not enough shared memory for kernel\n"); + exit(EXIT_FAILURE); + } + block_y /= 2, block_size /= 2; + smem = (2 + block_y) * C * sizeof(floatX); + status = cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); } + int grid_size = CEIL_DIV(N, block_y); + kernel<<>>(args..., N, C); cudaCheck(cudaGetLastError()); } -void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream) { +template +void layernorm_forward(TensorGPU out, tensor32 mean, tensor32 rstd, + tensorX inp, const tensorX weight, const tensorX bias, + int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 256; - assert(N % (block_size * x128::size) == 0); - const int grid_size = CEIL_DIV(N, block_size * x128::size); - residual_forward_kernel<<>>(out, inp1, inp2); - cudaCheck(cudaGetLastError()); + launch_layernorm_kernel(layernorm_forward_kernel6, N, C, stream, out, mean, rstd, inp, weight, bias); } -void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, float* rstd, - const floatX* inp1, const floatX* inp2, - const floatX* weight, const floatX* bias, - int N, int C, cudaStream_t stream) { - const int block_size = 256; - int block_y = block_size / WARP_SIZE; - const int grid_size = CEIL_DIV(N, block_y); - size_t smem = (2 + block_y) * C * sizeof(floatX); - - // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute - // this may fail, in which case we fall back to the smem free implementation. 
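The new launch_layernorm_kernel helper above relies on the explicit opt-in for large dynamic shared memory; the pattern in isolation looks like the sketch below (illustrative, with a hypothetical kernel name).

#include <cstdio>
#include <cuda_runtime.h>

__global__ void my_smem_kernel(float* out) {
    extern __shared__ float smem[];      // dynamic shared memory, sized at launch time
    smem[threadIdx.x] = (float)threadIdx.x;
    __syncthreads();
    if (threadIdx.x == 0) out[0] = smem[0];
}

void launch_with_large_smem(float* out, size_t smem_bytes, cudaStream_t stream) {
    // kernels only get more than 48 KiB of dynamic shared memory per block if explicitly opted in:
    cudaError_t status = cudaFuncSetAttribute(my_smem_kernel,
                                              cudaFuncAttributeMaxDynamicSharedMemorySize,
                                              (int)smem_bytes);
    if (status != cudaSuccess) {
        // the device cannot provide this much shared memory; shrink the request or fall back
        fprintf(stderr, "requested %zu bytes of shared memory is not available\n", smem_bytes);
        return;
    }
    my_smem_kernel<<<1, 32, smem_bytes, stream>>>(out);
}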
- cudaCheck(cudaGetLastError()); - auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cudaCheck(cudaGetLastError()); - if(status == cudaSuccess) { - fused_residual_forward_kernel5<<>>(residual, normed, - mean, rstd, inp1, inp2, - weight, bias, N, C); - } else { - residual_forward(residual, inp1, inp2, N*C, stream); - layernorm_forward(normed, mean, rstd, residual, weight, bias, N, 1, C, stream); - } - cudaCheck(cudaGetLastError()); +template +void fused_residual_forward5(tensorX residual, TensorGPU normed, tensor32 mean, tensor32 rstd, + tensorX inp1, TensorGPU inp2, tensorX weight, tensorX bias, + int N, int C, cudaStream_t stream=main_stream) { + NVTX_RANGE_FN(); + launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); } -void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, - int B, int T, int C, cudaStream_t stream) { +template +void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensor32 scratch, + const TensorGPU dout, const tensorX inp, const tensorX weight, tensor32 mean, tensor32 rstd, + int BT, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3 @@ -500,6 +401,10 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 - layernorm_backward_kernel10<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + if (dinp_old.is_null()) { + layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); + } else { + layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); + } cudaCheck(cudaGetLastError()); } diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index becc372c6..830797bed 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -10,12 +10,17 @@ Matrix Multiplication, with help from cuBLASLt // GELU can be either fused (cublasLt) or non-fused (gelu.h) #include "gelu.cuh" +// todo - does this need to be included globally? 
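The layernorm_backward launcher above picks between the zero_dinp_old=true and =false instantiations at runtime; the general shape of that compile-time dispatch, as a small standalone sketch with a hypothetical kernel, is:

// Sketch: turn a runtime flag into a template parameter so the kernel body can use
// `if constexpr` and skip the dead loads entirely (hypothetical kernel, not the patch code).
template <bool ZeroOldGrad>
__global__ void accumulate_kernel(float* dst, const float* old_grad, const float* new_grad, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    if constexpr (ZeroOldGrad) {
        dst[i] = new_grad[i];                  // overwrite: the old gradient is known to be unused
    } else {
        dst[i] = old_grad[i] + new_grad[i];    // accumulate into the existing gradient
    }
}

void accumulate(float* dst, const float* old_grad, const float* new_grad, int n,
                bool zero_old_grad, cudaStream_t stream) {
    const int block = 256, grid = (n + block - 1) / block;
    if (zero_old_grad) {
        accumulate_kernel<true><<<grid, block, 0, stream>>>(dst, old_grad, new_grad, n);
    } else {
        accumulate_kernel<false><<<grid, block, 0, stream>>>(dst, old_grad, new_grad, n);
    }
}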
+#include "copy_and_fp8.h" + // ---------------------------------------------------------------------------- // CUDA kernels -template -__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int B, int T, int OC, +template +__global__ void matmul_backward_bias_kernel9(TensorGPU dbias, TensorGPU dout, int BT, int OC, std::bool_constant) { + // todo - this kernel is way more complicated than it needs to be + // (should look at my old PR to simplify it again after this) constexpr const int bdx = 4; constexpr const int bdy = WARP_SIZE / bdx; assert(blockDim.x == bdx); @@ -25,33 +30,33 @@ __global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout int warp_c = (int)threadIdx.y; int block_d = (int)threadIdx.z; - const int OC_per_warp = bdy * x128::size; // 64 at BF16 + const int OC_per_warp = bdy * Packed128::size; // 64 at BF16 - int local_oc = warp_c * x128::size; + int local_oc = warp_c * Packed128::size; int global_oc = blockIdx.x * OC_per_warp + local_oc; int local_bt = warp_d + bdx * block_d; int bt_per_block = bdx * blockDim.z; - float accumulators[x128::size]; - for (int k = 0; k < x128::size; k++) { + float accumulators[Packed128::size]; + for (int k = 0; k < Packed128::size; k++) { accumulators[k] = 0.0f; } if(global_oc < OC) { // sum up over all bt within registers - for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) { - x128 packed_dout = load128(dout + global_oc + idx*OC); - for (int k = 0; k < x128::size; k++) { - accumulators[k] += (float)packed_dout[k]; + for (int idx = blockIdx.y * bt_per_block + local_bt; idx < BT; idx += gridDim.y * bt_per_block) { + auto dout128 = load_tensor128(dout, global_oc + idx*OC); + for (int k = 0; k < Packed128::size; k++) { + accumulators[k] += dout128.get(k); } } } - __shared__ float sub_results[x128::size][WARP_SIZE][bdy]; + __shared__ float sub_results[Packed128::size][WARP_SIZE][bdy]; // reduce within-warp results - for (int k = 0; k < x128::size; k++) { + for (int k = 0; k < Packed128::size; k++) { float v = accumulators[k]; v += __shfl_down_sync(0xffffffff, v, 1, 4); v += __shfl_down_sync(0xffffffff, v, 2, 4); @@ -62,7 +67,7 @@ __global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout __syncthreads(); // block-wide reductions - for (int k = block_d; k < x128::size; k += blockDim.z) { + for (int k = block_d; k < Packed128::size; k += blockDim.z) { float a = 0.f; for (int r = warp_d; r < blockDim.z; r += bdx) { float v = sub_results[k][r][warp_c]; @@ -106,17 +111,14 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s // Wrapper around cublasLtMatmul that is meant to support everything we need in llm.c // https://docs.nvidia.com/cuda/cublas/#cublasltmatmul -void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* bias, +void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX bias, int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false, int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0, - bool accumulate=false, floatX* pre_gelu=NULL, bool backward=false) + bool accumulate=false, tensorX pre_gelu=null_tensorX, bool backward=false) { NVTX_RANGE_FN(); - bool has_bias = (bias != NULL); - bool has_gelu = (pre_gelu != NULL); - // check alignment (some modes work unaligned but it always best to be aligned for performance) - if(((uintptr_t)a % 16) != 0 || ((uintptr_t)b % 16) != 0 || ((uintptr_t)d % 16) != 0 || 
((uintptr_t)bias % 16) != 0) { + if(((uintptr_t)a.data_ptr % 16) != 0 || ((uintptr_t)b.data_ptr % 16) != 0 || ((uintptr_t)d.data_ptr % 16) != 0 || ((uintptr_t)bias.data_ptr % 16) != 0) { printf("All cuBLASLt pointers must be aligned!\n"); exit(EXIT_FAILURE); } @@ -149,8 +151,7 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* } else { cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, CUBLAS_LOWP, k, n, k)); } - // cuBLASLt requires C in FP8 mode to be BF16 or FP32... (sigh) - cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, (sizeof(floatX) == 1) ? CUDA_R_16BF : CUBLAS_LOWP, m, n, m)); + cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, CUBLAS_LOWP, m, n, m)); cublasCheck(cublasLtMatrixLayoutCreate(&DLayout, CUBLAS_LOWP, m, n, m)); // Strided Batched GEMM (used for non-flash attention, equivalent to cublasGemmStridedBatchedEx) @@ -173,30 +174,31 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* // setup epilogue and associated pointers for bias & gelu cublasLtEpilogue_t epilogue; - if (has_gelu) { + if (pre_gelu.enabled()) { int64_t gelu_ld = m; // todo - is this affected by anything else? cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &gelu_ld, sizeof(gelu_ld))); - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu, sizeof(pre_gelu))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu.data_ptr, sizeof(pre_gelu.data_ptr))); if (backward) { - assert(!has_bias); // we shouldn't have any backward matmuls that use both GELU and bias + assert(!bias.enabled()); // we shouldn't have any backward matmuls that use both GELU and bias epilogue = CUBLASLT_EPILOGUE_DGELU; } else { - epilogue = has_bias ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX; + epilogue = bias.enabled() ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX; } - } else if(has_bias){ + } else if(bias.enabled()){ epilogue = backward ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS; } else { epilogue = CUBLASLT_EPILOGUE_DEFAULT; } cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); - if (has_bias) { + if (bias.enabled()) { // cuBLASLt requires bias in FP8 mode to be BF16... (sigh) cublasDataType_t bias_data_type = (sizeof(floatX) == 1) ? CUDA_R_16BF : CUBLAS_LOWP; // force BF16 bias for FP8 mode cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_data_type, sizeof(bias_data_type))); - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias.data_ptr, sizeof(bias.data_ptr))); } + // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) 
cublasDataType_t scale_type = CUDA_R_32F; cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); @@ -205,7 +207,101 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, ALayout, BLayout, CLayout, DLayout, preference, 1, &heuristic, &returnedResults); if (returnedResults == 0) { - printf("No cuBLASLt algorithm: m: %d, n: %d, k: %d, bias: %d\n", n, m, k, has_bias); + printf("No cuBLASLt algorithm: m: %d, n: %d, k: %d, bias: %d\n", n, m, k, bias.enabled()); + exit(EXIT_FAILURE); + } + + // set whether to accumulate (i.e. D += C) or not - note this isn't considered in algorithm selection (?!) + const float alpha = 1.0f, beta = accumulate ? 1.0f : 0.0f; + + // call the matmul + cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc, + &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout, + &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream)); + + #ifdef FAKE_LOW_PRECISION + update_absmax(d, false); // fake FP8 requires the absmax to work (cuBLAS can't do it for BF16) + #endif + + // cleanups + cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); + cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); + cublasCheck(cublasLtMatrixLayoutDestroy(ALayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(BLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(CLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(DLayout)); + cudaCheck(cudaGetLastError()); +} + +#ifdef ENABLE_FP8 +template +void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU b, const tensorX bias, + int m, int n, int k, cudaStream_t stream=main_stream, + bool accumulate=false, bool backward=false) +{ + NVTX_RANGE_FN(); + if(((uintptr_t)a.data_ptr % 16) != 0 || ((uintptr_t)b.data_ptr % 16) != 0 || ((uintptr_t)d.data_ptr % 16) != 0 || ((uintptr_t)bias.data_ptr % 16) != 0) { + printf("All cuBLASLt pointers must be aligned!\n"); + exit(EXIT_FAILURE); + } + + // create the operation descriptor + cublasLtMatmulDesc_t operationDesc; + cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute, CUDA_R_32F)); + + cublasOperation_t opTranspose = CUBLAS_OP_T, opNoTranspose = CUBLAS_OP_N; + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose))); + + // define matrix layouts + cublasLtMatrixLayout_t ALayout, BLayout, CLayout, DLayout; + cublasDataType_t typeA = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; + cublasDataType_t typeB = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; + cublasDataType_t typeD = std::is_same::value ? CUBLAS_LOWP : + (std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2); + + cublasCheck(cublasLtMatrixLayoutCreate(&ALayout, typeA, k, m, k)); // always transposed for FP8 + cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, typeB, k, n, k)); // never transposed for FP8 + cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, CUBLAS_LOWP, m, n, m)); // must be BF16 for accumulation in cuBLASLt + cublasCheck(cublasLtMatrixLayoutCreate(&DLayout, typeD, m, n, m)); + + // setup epilogue and associated pointers for bias + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + if(bias.data_ptr != NULL) { + epilogue = backward ? 
CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS; + cublasDataType_t bias_data_type = CUBLAS_LOWP; // BF16 bias for FP8 mode + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_data_type, sizeof(bias_data_type))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias.data_ptr, sizeof(bias.data_ptr))); + } + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // FP8 scale factors and absmax pointers + float* a_descale_ptr = a.scale_descale_ptr + 1; + float* b_descale_ptr = b.scale_descale_ptr + 1; + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_descale_ptr, sizeof(float*))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_descale_ptr, sizeof(float*))); + if (sizeof(Td) == 1) { + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d.scale_descale_ptr, sizeof(float*))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &d.absmax_ptr, sizeof(float*))); + } + + cublasDataType_t scale_type = CUDA_R_32F; + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + + // create a preference handle with specified max workspace + cublasLtMatmulPreference_t preference; + cublasCheck(cublasLtMatmulPreferenceCreate(&preference)); + cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &cublaslt_workspace_size, sizeof(cublaslt_workspace_size))); + + // find a suitable algorithm (cached internally so shouldn't take much CPU time in practice) + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristic; + cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, ALayout, BLayout, CLayout, DLayout, + preference, 1, &heuristic, &returnedResults); + + if (returnedResults == 0) { + printf("No cuBLASLt FP8 algorithm: m: %d, n: %d, k: %d, bias: %d\n", n, m, k, (bias.data_ptr != NULL)); exit(EXIT_FAILURE); } @@ -226,30 +322,34 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* cublasCheck(cublasLtMatrixLayoutDestroy(DLayout)); cudaCheck(cudaGetLastError()); } +#endif +template // small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments) -void matmul_forward_cublaslt(floatX* out, - floatX* inp, floatX* weight, floatX* bias, - int B, int T, int C, int OC, cudaStream_t stream, - floatX* pre_gelu=NULL, int gelu_fusion=1) { - // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) - if (gelu_fusion < 1 && pre_gelu) { - matmul_cublaslt(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); - gelu_forward(out, pre_gelu, B*T*OC, stream); +void matmul_forward(TensorGPU out, + TensorGPU inp, TensorGPU weight, tensorX bias, int BT, int C, int OC, + TensorGPU pre_gelu=TensorGPU(), int gelu_fusion=1, cudaStream_t stream=main_stream) { + if constexpr (sizeof(Tin) == 1) { + matmul_cublaslt_fp8(pre_gelu.enabled() ? 
pre_gelu : out, weight, inp, bias, OC, BT, C, stream, false, false); + if (pre_gelu.enabled()) { + gelu_forward(out, pre_gelu, stream); + } } else { - matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); + if (pre_gelu.enabled() && gelu_fusion < 1) { + matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, null_tensorX, false); + gelu_forward(out, pre_gelu, stream); + } else { + matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); + } } } -void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, - floatX* dout, floatX* inp, floatX* weight, - float* dbias_buffer, - int B, int T, int C, int OC, cudaStream_t stream, - floatX* pre_gelu=NULL, int gelu_fusion=1) { +template +void matmul_backward_bias(tensorX dbias, TensorGPU dout, tensor32 scratch, int BT, int OC, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // backward to bias, if given, does a += - if (dbias != NULL) { + if (dbias != null_tensorX) { // Each warp is responsible for 8 * "x128::size" = 64 OCs at BF16 (OC must be a multiple of 64!) // Block size is 1024 | 768 threads (32|24 warps) and we reduce those values into 1 at the end @@ -263,28 +363,89 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation // and write results directly to the output. if(grid_size_y == 1) { - matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, False); + matmul_backward_bias_kernel9<<>>(dbias, dout, BT, OC, False); cudaCheck(cudaGetLastError()); } else { // kernel 9 overwrites temp buffer, so no need to memset - matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, True); + matmul_backward_bias_kernel9<<>>(scratch, dout, BT, OC, True); cudaCheck(cudaGetLastError()); - reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y); + reduce_add_sum_kernel<<>>(dbias, scratch, OC, grid_size_y); cudaCheck(cudaGetLastError()); } - dbias = NULL; // prevent dbias calculation from also being fused in matmul_cublaslt below (if we enabled fusion) } +} + +template +void matmul_backward_fp8(tensor8e5 dinp, tensorX dweight, tensorX dbias, + TensorGPU dout, tensor8 inp, tensor8 weight, + tensor32 scratch1_big, tensor32 scratch2_huge, + int BT, int C, int OC, + tensor8 pre_gelu_activation=tensor8(), cudaStream_t stream=main_stream) { +#ifndef ENABLE_FP8 + // FP8 is not enabled so we use the regular floatX matmul path + matmul_backward(dinp, dweight, dbias, dout, inp, weight, scratch1_big, BT, C, OC, pre_gelu_activation, 1, stream); +#else + NVTX_RANGE_FN(); + matmul_backward_bias(dbias, dout, scratch1_big, BT, OC, stream); + + // N.B.: Both scratch1 and scratch2 are guaranteed to be big enough for 4BTC and 4CC in FP8 + // IMPORTANT: inp is allowed to be the same buffer as scratch2_huge (e.g. for fch_gelu) + // ==> this MUST be done first and write to scratch1_big! + // transpose input + tensor8 inp_fp8_transposed = inp; + inp_fp8_transposed.data_ptr = (float8*)scratch1_big.data_ptr; + transpose_simple(inp_fp8_transposed, inp, C, BT, stream); + + // convert dout to FP8e5 if it is not already, and transpose it + // the buffer is guaranteed to be at least twice as big as 4BTC, so we can split it in 2 + // todo - merge conversion and tranposition like we did before? 
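The explicit transpose_simple calls in matmul_backward_fp8 exist because this FP8 path keeps every GEMM in the single A-transposed/B-not-transposed layout set up in matmul_cublaslt_fp8 above, so the operands of each backward product are materialised pre-transposed first. As a shape-level reference, the two backward GEMMs compute the following (illustrative CPU sketch, same row-major (BT, C)/(OC, C) convention as the rest of llm.c):

// forward:   out[bt, o]    = sum_c  inp[bt, c]  * weight[o, c]
// backward:  dinp[bt, c]   = sum_o  dout[bt, o] * weight[o, c]
//            dweight[o, c] += sum_bt dout[bt, o] * inp[bt, c]
void matmul_backward_reference(float* dinp, float* dweight,
                               const float* dout, const float* inp, const float* weight,
                               int BT, int C, int OC) {
    for (int bt = 0; bt < BT; bt++) {
        for (int c = 0; c < C; c++) {
            float acc = 0.0f;
            for (int o = 0; o < OC; o++) acc += dout[bt * OC + o] * weight[o * C + c];
            dinp[bt * C + c] = acc;           // "=" semantics, like the dinp GEMM above
        }
    }
    for (int o = 0; o < OC; o++) {
        for (int c = 0; c < C; c++) {
            float acc = 0.0f;
            for (int bt = 0; bt < BT; bt++) acc += dout[bt * OC + o] * inp[bt * C + c];
            dweight[o * C + c] += acc;        // "+=" semantics, like the dweight GEMM above
        }
    }
}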
+ tensor8e5 dout_fp8 = *(tensor8e5*)&dout; + if constexpr (std::is_same::value == false) { + dout_fp8.data_ptr = (float8e5*)(scratch2_huge.data_ptr); + copy_advanced(dout_fp8, dout, stream); + } + tensor8e5 dout_fp8_transposed = dout_fp8; + dout_fp8_transposed.data_ptr = (float8e5*)(scratch2_huge.data_ptr + (scratch2_huge.num_elements / 2)); + transpose_simple(dout_fp8_transposed, dout_fp8, OC, BT, stream); + + // GEMM 1: dweight, inp_fp8_transposed, dout_fp8_transposed + matmul_cublaslt_fp8(dweight, inp_fp8_transposed, dout_fp8_transposed, null_tensorX, C, OC, BT, stream, false, true); + + // transpose weight (todo: option to cache this / do it at optimizer time) + tensor8 weight_fp8_transposed = weight; + weight_fp8_transposed.data_ptr = (float8*)scratch1_big.data_ptr; + transpose_simple(weight_fp8_transposed, weight, C, OC, stream); + + // GEMM 2: dinp, weight_fp8_transposed, dout_fp8 + matmul_cublaslt_fp8(dinp, weight_fp8_transposed, dout_fp8, null_tensorX, C, BT, OC, stream, false, true); + + // todo - need dinp and dinp_pre_gelu passed separately here, important for UNIQUE_TENSOR_MEMORY! + // todo - need to support BF16 for dinp into gelu_backwasrd() with FP8 out of gelu_backward()! + if (pre_gelu_activation.enabled()) { + gelu_backward_fp8(dinp, dinp, pre_gelu_activation, stream); + } +#endif +} + + +void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, + tensorX dout, tensorX inp, tensorX weight, + tensor32 dbias_scratch, + int BT, int C, int OC, + tensorX pre_gelu_activation=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { + NVTX_RANGE_FN(); + matmul_backward_bias(dbias, dout, dbias_scratch, BT, OC, stream); // backward to input, uses = in the backward pass (set the gradient) - matmul_cublaslt(dinp, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false, - gelu_fusion >= 2 ? pre_gelu : NULL, true); + matmul_cublaslt(dinp, weight, dout, null_tensorX, C, BT, OC, stream, false, false, 0, 0, 0, 0, false, + gelu_fusion >= 2 ? pre_gelu_activation : null_tensorX, true); // backward GELU (if it wasn't fused into the matmul above) - if (gelu_fusion < 2 && pre_gelu) { - gelu_backward_inplace(dinp, pre_gelu, B*T*C, stream); + if ( pre_gelu_activation.enabled() && gelu_fusion < 2) { + gelu_backward(dinp, dinp, pre_gelu_activation, stream); } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one - matmul_cublaslt(dweight, inp, dout, NULL /*dbias*/, C, OC, B*T, stream, false, true, 0, 0, 0, 0, - true /* accumulate */, NULL, true); + matmul_cublaslt(dweight, inp, dout, null_tensorX /*dbias*/, C, OC, BT, stream, false, true, 0, 0, 0, 0, + true /* accumulate */, null_tensorX, true); } diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh new file mode 100644 index 000000000..cb996f6a3 --- /dev/null +++ b/llmc/tensor.cuh @@ -0,0 +1,611 @@ +#ifndef TENSOR_CUH +#define TENSOR_CUH + +// ... +//#define FAKE_LOW_PRECISION +#define UNIQUE_TENSOR_MEMORY false +#define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled +// ... 
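Before the full header, it may help to see the idiom the rest of this patch uses everywhere: wrap a raw pointer plus its scale/absmax state in a TensorGPU, read 128-bit packets through load_tensor128, write through a new_tensor128 accumulator, and fold the absmax statistics in once per kernel. A condensed sketch of that usage, modelled on gelu_forward_kernel2 above, with an elementwise scale as the example op (illustrative only; like the GELU launcher, it assumes the element count is a multiple of the packet size times the launch size):

// Hypothetical example kernel written against this API, following the same pattern as the kernels above.
template <typename T>
__global__ void scale_kernel(TensorGPU<T> out, TensorGPU<T> inp, float alpha) {
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128();
    auto out128 = new_tensor128(out);
    auto inp128 = load_tensor128(inp, idx, true);   // streaming load; get() returns descaled floats
    for (int k = 0; k < inp.num_per_128(); ++k) {
        out128.set(k, alpha * inp128.get(k));       // set() re-applies the output scale if needed
    }
    out128.store(idx);
    out128.update_absmax(1);                        // contribute to the output tensor's absmax statistics
}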
+ +#include "cuda_common.h" +#include "cuda_utils.cuh" +#include + +// ---------------------------------------------------------------------------- + +enum TT : uint8_t { + PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each + MULTIUSE, // single allocation shared for activations, activation gradients, and scratch + DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_MASTER+1 +}; + +enum TFlags : uint8_t { + NONE=0, + GRADIENT=1, + REUSED_MEMORY=2, + TENSOR_2D=4, // used for matmul weights and activation outputs only (not inputs or gradients) + BIAS=8, + LAYERNORM=16, + RESIDUAL=32, + EMBEDDING=64, + STATS=128 +}; + +// ---------------------------------------------------------------------------- +// forward declarations & extern variables defined in the training file +struct TensorSpec; +constexpr size_t MAX_TENSORS = 32768; // only increases CPU memory usage if unused +constexpr size_t MAX_ABSMAX_HISTORY = 32; // todo - command line option + +extern TensorSpec tensor_specs[MAX_TENSORS]; +extern TensorSpec* tensor_specs_gpu; +extern size_t tensors_start[TT::COUNT]; +extern size_t tensors_bytes[TT::COUNT]; +extern size_t tensors_elements[TT::COUNT]; +extern int num_tensor_specs; + +extern TT current_tensor_type; // todo - avoid having this somehow? +extern int absmax_history_index; // todo - move into model struct? +extern float* gpu_scale_memory; +extern unsigned int* gpu_absmax_memory; +// end element of each tensor to optimise iterating through them in kernels +extern size_t* gpu_tensor_end_element; + +__device__ __constant__ TensorSpec* tensor_specs_ptr; +__device__ __constant__ float* gpu_scale_memory_ptr; +__device__ __constant__ unsigned int* gpu_absmax_memory_ptr; +__device__ __constant__ size_t* tensor_end_element_ptr; + +// ---------------------------------------------------------------------------- +// Helper macros for accessing tensors in the training loop +#define TENSOR(x,layer) get_tensor(x, DEFAULT, layer) +#define ACT_L(x,layer) get_tensor(model->acts.x, MULTIUSE, layer) +#define MULTI_L(x,layer) get_tensor(model->multiuse.x, MULTIUSE, layer) +#define AGRAD_L(x,layer) get_tensor(model->acts_grads.x, MULTIUSE, layer) +#define PARAM_L(x,layer) get_tensor(model->params[PARAMETER].x, PARAMETER, layer) +#define PGRAD_L(x,layer) get_tensor(model->params[PARAMETER_GRAD].x, PARAMETER_GRAD, layer) +#define ACT(x) ACT_L(x,l) +#define MULTI(x) MULTI_L(x,l) +#define AGRAD(x) AGRAD_L(x,l) +#define PARAM(x) PARAM_L(x,l) +#define PGRAD(x) PGRAD_L(x,l) +#define ACT_0(x) ACT_L(x,0) +#define MULTI_0(x) MULTI_L(x,0) + +// ---------------------------------------------------------------------------- + +template +struct TensorGPU { + int id = -1; // TensorSpec index in tensor_specs[] array + ElementType* data_ptr = NULL; + float* scale_descale_ptr = NULL; + unsigned int* absmax_ptr = NULL; + size_t num_elements = 0; + + static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { + TensorGPU tmp; + tmp.data_ptr = ptr; + return tmp; + } + template + __device__ __host__ T* as() { + return reinterpret_cast(data_ptr); + } + __device__ __host__ operator ElementType*() const { + return data_ptr; + } + __device__ __host__ ElementType& operator[](size_t index) { + return data_ptr[index]; + } + __device__ __host__ const ElementType& operator[](size_t index) const { + return data_ptr[index]; + } + __device__ __host__ int num_per_128() const { + return sizeof(int4) / sizeof(ElementType); + } + __device__ __host__ bool is_null() const { + return 
(data_ptr == NULL); + } + __device__ __host__ bool enabled() const { + return (data_ptr != NULL); + } + + static constexpr bool no_scaling = (sizeof(ElementType) != 1); // todo - this prevents scaling FP16 + + __device__ __host__ float get_scalar(size_t index, bool disable_scaling=no_scaling) const { + #ifdef FAKE_LOW_PRECISION + disable_scaling = true; + #endif + ElementType* __restrict__ data_ptr_restricted = data_ptr; + float* __restrict__ scale_ptr_restricted = scale_descale_ptr; + + float value = (float)data_ptr_restricted[index]; + float descale = (scale_descale_ptr && !disable_scaling) ? scale_ptr_restricted[1] : 1.0f; + return value * descale; // [1] = descale + } + + __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=no_scaling) { + #ifdef FAKE_LOW_PRECISION + disable_scaling = true; + #endif + ElementType* __restrict__ data_ptr_restricted = data_ptr; + float* __restrict__ scale_ptr_restricted = scale_descale_ptr; + + float scale = (scale_descale_ptr && !disable_scaling) ? scale_ptr_restricted[0] : 1.0f; + ElementType output = (ElementType)(value * scale); + data_ptr_restricted[index] = output; + return output; + } +}; + +typedef TensorGPU tensorX; +typedef TensorGPU tensor32; +typedef TensorGPU tensorFP16; +typedef TensorGPU tensorBF16; +#ifdef ENABLE_FP8 +typedef TensorGPU<__nv_fp8_e4m3> tensor8; +typedef TensorGPU<__nv_fp8_e5m2> tensor8e5; +#else +typedef TensorGPU tensor8; +typedef TensorGPU tensor8e5; +#endif +extern TensorGPU null_tensorX; + +// ---------------------------------------------------------------------------- + +// this is the "foundation" of the other tensor classes (TensorGPU and tensor128) +// they all implicitly refer to this (in tensor_specs[] and tensor_specs_gpu[] for now) with the id +// and these other classes are created by converting from this one (sometimes implicitly) +struct TensorSpec { + int id; + char* ptr; // = model->tensor_memory[tensor_type] + offset + char name[16]; + TT tensor_type; + DType data_type; + short tensor_flags; + + size_t offset; // into tensor type's base pointer + size_t start_element; // on this shard + size_t num_elements; // per shard + short num_shards; + short remaining_layers; + + template + __host__ __device__ operator T*() const { + // todo - sanity check DType matches T + return reinterpret_cast(ptr); + } + + template + __device__ __host__ operator TensorGPU() const { + TensorGPU tensor; + tensor.num_elements = num_elements; + tensor.data_ptr = this->operator T*(); + tensor.id = id; + + #ifdef __CUDA_ARCH__ + tensor.scale_descale_ptr = gpu_scale_memory_ptr + 2*id; + tensor.absmax_ptr = gpu_absmax_memory_ptr + id; + #else + tensor.scale_descale_ptr = gpu_scale_memory + 2*id; + tensor.absmax_ptr = gpu_absmax_memory + id; + #endif + + return tensor; + } +}; + +// ---------------------------------------------------------------------------- + +// debug helper function (enable in get_tensor() for extreme logging) +void print_tensor_elements(int tensor_id) { + TensorSpec spec = tensor_specs[tensor_id]; + size_t num_elements = spec.num_elements; + const char* tensor_name = spec.name; + TT tensor_type = spec.tensor_type; + DType dtype = spec.data_type; + size_t element_size = sizeof_dtype(dtype); + + void* gpu_tensor = spec.ptr; + void* cpu_tensor = malloc(num_elements * element_size); + + // Get scale from GPU + float scale, descale, absmax; + cudaMemcpy(&scale, &gpu_scale_memory[spec.id * 2], sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(&descale, &gpu_scale_memory[spec.id * 2 + 
1], sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(&absmax, &gpu_absmax_memory[spec.id], sizeof(float), cudaMemcpyDeviceToHost); + + printf("Printing tensor %s (tensor_type: %d, data_type: %d, flags: %d)\n", tensor_name, (int)tensor_type, (int)dtype, spec.tensor_flags); + printf("GPU memory: %p\n", gpu_tensor); + printf("CPU memory: %p\n", cpu_tensor); + printf("Num elements: %zu\n", num_elements); + printf("Element size: %zu\n", element_size); + printf("Offset: %zu\n", spec.offset); + printf("Scale: %f, Descale: %f, Absmax: %f\n", scale, descale, absmax); + + cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); + + printf("First 4 & Last 4 of %s:\n", tensor_name); + for (int i = 0; i < 8; i++) { + int idx = (i < 4) ? i : num_elements - 8 + i; + switch (dtype) { + case DType::FP32: printf("%.16f ", ((float*)cpu_tensor)[idx]); break; + case DType::FP16: printf("%.16f ", (float)((__nv_half*)cpu_tensor)[idx]); break; + case DType::BF16: printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[idx]); break; + case DType::FP8E4M3: printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[idx]); break; + case DType::FP8E5M2: printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[idx]); break; + } + if (i == 3) printf("\n"); + } + printf("\n\n"); + + free(cpu_tensor); +} + +// ---------------------------------------------------------------------------- + +TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { + TensorSpec spec = tensor_specs[spec_index]; + if (layer > 0 && spec.remaining_layers >= layer) { + spec = tensor_specs[spec_index + layer]; + } else if (layer > 0 && spec.remaining_layers > 0) { + printf("ERROR: get_tensor() for %s layer %d but only %d layers remaining\n", spec.name, layer, spec.remaining_layers); + assert(false); + } + assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); + //print_tensor_elements(spec.id); // enable for extreme debugging + return spec; +} + +// this can only be called at initialisation time, once tensor_specs has been uploaded to the GPU, it is fixed in stone +int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, int flags=TFlags::NONE, TT tensor_type=TT::DEFAULT) { + assert(num_tensor_specs < MAX_TENSORS); + assert((total_elements % num_shards) == 0); + TensorSpec* spec = &tensor_specs[num_tensor_specs]; + + spec->id = num_tensor_specs; + strncpy(spec->name, name, 15); + spec->name[15] = 0; + spec->tensor_type = (tensor_type == TT::DEFAULT) ? current_tensor_type : tensor_type; + spec->data_type = data_type; + spec->tensor_flags = flags; + + // parameter tensors must fit in a 32-bit unsigned integer (used as an optimisation in e.g. global_norm_tensors_loop) + // todo - either 1) 32-bit everywhere (with a DEFINE?), 2) 64-bit everywhere despite the small performance impact, 3) ? 
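    // [editor's note] for a sense of scale: even GPT-2 XL (L=48, C=1600) stays well below this limit, since its
    // largest parameter tensor is fcw with L * 4*C*C = 48 * 4 * 1600 * 1600, roughly 0.49 billion elements,
    // while the shared MULTIUSE allocation (activations, gradients and scratch at large B*T) may exceed 4B
    // elements, which is why it is exempted in the assert below.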
+ assert(total_elements < 4UL*1024*1024*1024 || spec->tensor_type == TT::MULTIUSE); + + spec->start_element = tensors_elements[spec->tensor_type]; + spec->num_elements = total_elements / num_shards; + spec->num_shards = num_shards; + spec->remaining_layers = 0; + + if (copy_offset_from >= 0) { + TensorSpec base_spec = tensor_specs[copy_offset_from]; + base_spec.tensor_flags |= (flags & REUSED_MEMORY); + spec->offset = base_spec.offset; + + size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); + size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); + assert(new_tensor_bytes <= original_tensor_bytes); + assert(spec->tensor_type == base_spec.tensor_type); + } else { + spec->offset = tensors_bytes[spec->tensor_type]; + tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); + if (tensors_start[spec->tensor_type] == 0 && spec->tensor_type != 0) { + tensors_start[spec->tensor_type] = num_tensor_specs; + } + } + + tensors_elements[spec->tensor_type] += spec->num_elements; + return num_tensor_specs++; +} + +int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, + int copy_offset_from=-1, int flags=TFlags::NONE, int reuse_every_n_layers=0, + TT tensor_type=TT::DEFAULT) { + int first_tensor_id = num_tensor_specs; + if (reuse_every_n_layers > 0 && num_layers > 1) { + flags |= REUSED_MEMORY; + } + for (int l = 0; l < num_layers; l++) { + char layer_name[16]; + assert(snprintf(layer_name, 15, "%s_%d", name, l) >= 0); + if (reuse_every_n_layers > 0 && l >= reuse_every_n_layers) { + copy_offset_from = first_tensor_id + (l % reuse_every_n_layers); + } + int spec = add_tensor_spec(num_layers > 1 ? layer_name : name, total_elements, num_shards, data_type, copy_offset_from, flags, tensor_type); + tensor_specs[spec].remaining_layers = num_layers - (l + 1); + } + return first_tensor_id; +} + +// ---------------------------------------------------------------------------- + +// the 1st num_tensor_specs values are the absmax of the current/last step +// the next [MAX_ABSMAX_HISTORY * num_tensor_specs] values are the history from previous steps +__global__ void update_scale_descale_kernel(int num_tensor_specs, int absmax_history_index) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_tensor_specs) return; + + // copy current absmax to history then clear it + gpu_absmax_memory_ptr[tid + (absmax_history_index * num_tensor_specs)] = gpu_absmax_memory_ptr[tid]; + gpu_absmax_memory_ptr[tid] = 0; + float absmax = 0.0f; + + // get the maximum absmax from the history (todo - do we want to mitigate outliers here?) + #pragma unroll + for (int i = 1; i <= MAX_ABSMAX_HISTORY; i++) { + absmax = max(absmax, __uint_as_float(gpu_absmax_memory_ptr[tid + (i * num_tensor_specs)])); + } + + // calculate scale based on the maximum absmax + float scale = (absmax != 0.0f) ? 
(1.0f / absmax) : 1.0f; + + // FP8 e4m3 vs e5m2 (the latter is currently only used for activation gradients) + bool use_e5m2 = (tensor_specs_ptr[tid].data_type == DType::FP8E5M2); + #ifdef FAKE_LOW_PRECISION + if (tensor_specs_ptr[tid].tensor_flags & TFlags::GRADIENT && tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE) { + use_e5m2 = true; + } + #endif + + if (use_e5m2) { + if (absmax != 0.0f) { + scale *= 32768.0f; + } else { + // hacky default to avoid extreme gradient underflow on the 1st step + scale = 4096.0f; + } + } else if (tensor_specs_ptr[tid].data_type == DType::FP8E4M3) { + // todo - power benefit of making sure top bit of exponent is (nearly always) zero? + // this can be done simply by *not* multiplying here, so that the "maximum" is 1.0f + // we probably want some threshold for badly behaved parameters to use the full range + //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER || absmax >= 4.0f) { + if (absmax != 0.0f) { + scale *= 256.0f; + } + } + + // update scale and descale memory + // descale must be delayed by one step for parameters (see comment in gpt2_update). + gpu_scale_memory_ptr[tid * 2] = scale; + + if (tensor_specs_ptr[tid].tensor_type == TT::PARAMETER) { + float old_scale = gpu_scale_memory_ptr[tid * 2]; + gpu_scale_memory_ptr[tid * 2 + 1] = 1.0f / old_scale; + } else { + gpu_scale_memory_ptr[tid * 2 + 1] = 1.0f / scale; + } +} + +void update_scales_from_absmax() { + int block_size = 256; + int num_blocks = CEIL_DIV(num_tensor_specs, block_size); + + update_scale_descale_kernel<<>>(num_tensor_specs, absmax_history_index + 1); + absmax_history_index = (absmax_history_index + 1) % MAX_ABSMAX_HISTORY; +} + +// ---------------------------------------------------------------------------- + +template +struct tensor128 { +private: + Packed128 data128; + ElementType* data_ptr; + unsigned int *absmax_ptr = nullptr; + float scale = 1.0f; + float descale = 1.0f; + float new_absmax = 0.0f; + bool wrote_data = false; + bool wrote_absmax = false; + int id = -1; + + // fake fp8 mode (ignored without FAKE_LOW_PRECISION define) + bool faking_low_precision = false; + bool faking_mode_e5 = false; + +public: + bool scaling = (sizeof(ElementType) == 1); + static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); + __device__ tensor128() { scaling = false; } + + __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { + data_ptr = tensor.data_ptr; + id = tensor.id; + +#ifdef FAKE_LOW_PRECISION + // fake FP8 only applies to specific tensors to test expected training performance + // todo - expand this to support more unusual formats and test things like blockwise scaling(?) + if (!disable_scaling && id >= 0 && sizeof(ElementType) == 2 && tensor_specs_ptr[id].tensor_type != TT::PARAMETER_GRAD) { + if ((tensor_specs_ptr[id].tensor_flags & (TFlags::RESIDUAL | TFlags::EMBEDDING | TFlags::BIAS)) == 0) { + faking_low_precision = true; + if ((tensor_specs_ptr[id].tensor_flags & TFlags::GRADIENT) && (tensor_specs_ptr[id].tensor_type == TT::MULTIUSE)) { + faking_mode_e5 = true; + } + } + } + scaling = false; // only do "fake" scaling +#endif + + scaling = scaling && !disable_scaling; + if (scaling) { + // using __restrict__ here should allow the compiler to cache/reuse this in loops etc. 
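            // [editor's note] layout reminder: gpu_scale_memory holds two floats per tensor id,
            // [2*id + 0] = scale (float value -> low precision storage) and [2*id + 1] = descale (storage -> float),
            // so an FP8 element e decodes as (float)e * descale and a float v is stored as (ElementType)(v * scale).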
+ const float* __restrict__ ptr_restricted = tensor.scale_descale_ptr; + scale = ptr_restricted[0]; + descale = ptr_restricted[1]; + } + absmax_ptr = tensor.absmax_ptr; + } + + __device__ void load(size_t offset, bool cache_streaming=false) { + ElementType* addr = data_ptr + offset; + data128 = cache_streaming ? load128cs(addr) : load128(addr); + } + + __device__ void store(size_t offset, bool cache_streaming=false) { + if (cache_streaming) store128cs(data_ptr + offset, data128); + else store128(data_ptr + offset, data128); + wrote_data = true; + } + + template + __device__ void store_same_length(size_t offset, bool cache_streaming=false) { + if (cache_streaming) store128_same_length_cs(data_ptr + offset, data128); + else store128_same_length(data_ptr + offset, data128); + wrote_data = true; + } + + __device__ const Packed128& get128() const { return data128; } + __device__ Packed128& get128() { return data128; } + + // call this manually if e.g. you use set_scalar() to update the tensor + // todo - in the future, this could support more than just absmax + __device__ void add_value_stats(float value, ElementType output=(ElementType)0.0f) { + new_absmax = max(new_absmax, fabsf(value)); + } + + // get() and set() automatically apply scaling & descaling for FP8 values + __device__ float get(int index) { + float value = (float)data128[index] * (scaling ? descale : 1.0f); + // used to simulate FP8 and below (just returns the input without FAKE_LOW_PRECISION) + value = fake_low_precision(faking_low_precision, value, scale, descale, faking_mode_e5); + return value; + } + + __device__ void set(int index, float value) { + float output = value * (scaling ? scale : 1.0f); + output = fake_low_precision(faking_low_precision, output, scale, descale, faking_mode_e5); + data128[index] = (ElementType)(output); + add_value_stats(value, data128[index]); + } + + __device__ void set_stochastic(int index, float value, unsigned int random_number, + int rotate_by_index=10, bool non_deterministic_rng=false) { + float scaled_value = value * (scaling ? scale : 1.0f); + + // rotate the random number by the index so we can cheaply reuse the same RNG + // obviously less good than having true per-index RNG, but should be good enough + // when rounding FP32 to FP8, most of the bits make extremely little difference anyway... + // x10 is used so that it never repeats for indices [0;15] with a minimum difference of 2 etc. 
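        // [editor's note] concretely, with rotate_by_index=10 the shift amounts are (10*index) mod 32 for index 0..15,
        // i.e. 0,10,20,30,8,18,28,6,16,26,4,14,24,2,12,22: sixteen distinct rotations (every even residue mod 32 once),
        // so no two elements written from the same 128-bit group reuse the random word with the same rotation.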
+        if (rotate_by_index) {
+            assert(index < 16); // >=16 would repeat and be extremely bad RNG
+            random_number = __funnelshift_l(random_number, random_number, index * rotate_by_index);
+        }
+
+        // RNG without a seed from the host for quick testing, but obviously not deterministic
+        // can be forced to get slightly different runs from which you can calculate an average
+        #ifdef FORCE_NON_DETERMINISM
+        non_deterministic_rng = true;
+        #endif
+        if (non_deterministic_rng) {
+            unsigned int clock, laneid;
+            asm volatile("mov.u32 %0, %%clock;" : "=r"(clock));
+            asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid));
+            random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x);
+        }
+
+        stochastic_rounding(scaled_value, data128[index], random_number);
+        add_value_stats(value, data128[index]);
+    }
+
+    // return value: if true, we can skip __syncthreads() in the calling function as we have just done one
+    __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) {
+        #ifdef FAKE_LOW_PRECISION
+        if (absmax_ptr == NULL || !faking_low_precision) {
+            return false;
+        }
+        forced = true;
+        #endif
+
+        if (!forced && !scaling) {
+            return false;
+        }
+        wrote_absmax = true;
+
+        // lane_id must be obtained directly from the special register
+        // otherwise, the compiler does silly things related to the redux/atomicMax
+        unsigned int lane_id;
+        asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id));
+        unsigned int num_warps = num_threads / WARP_SIZE;
+        unsigned int warp_id = thread_id / WARP_SIZE;
+
+        // use native integer reductions as much as possible (supported on all GPUs with FP8)
+        // this might treat NaN/INF slightly differently but that is the least of our problems
+        __shared__ unsigned int shared[32];
+        unsigned int absmax_uint = *(unsigned int*)&new_absmax;
+        asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint));
+
+        // with this condition instead of lane_id == 0, we have shared[lane_id] both here and below
+        // this reduces the number of instructions for addressing
+        if (lane_id == warp_id) {
+            shared[lane_id] = absmax_uint;
+        }
+
+        // sync can be after exit (dead threads don't count) but must be before return
+        // if this is the end of the kernel, the compiler puts a conditional EXIT right after BAR
+        // but this way the EXIT is right before the barrier which frees the warps slightly quicker
+        bool done = (warp_id != 0);
+        if (done && exit) asm volatile("exit;"); // todo - does this help enough to be worth it?
+        __syncthreads();
+        if (done && !exit) return true;
+
+        // one more warp reduction then global memory atomic
+        // we want as few global atomics as possible (i.e.
1 per threadblock) + absmax_uint = shared[lane_id]; + if (lane_id >= num_warps) { + absmax_uint = 0; + } + + asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + if (lane_id == 0) { + atomicMax(absmax_ptr, absmax_uint); + } + return true; + } + + // helper function to avoid having to specify threadIdx/blockDim manually + __device__ bool update_absmax(int block_dimensions, bool exit=false) { + if (block_dimensions == 1) { + return update_absmax(threadIdx.x, blockDim.x, exit); + } else if (block_dimensions == 2) { + return update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, exit); + } else if (block_dimensions == 3) { + return update_absmax(threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y, + blockDim.x * blockDim.y * blockDim.z, exit); + } + assert(false); + return false; + } + + __device__ void skip_absmax() { + wrote_absmax = true; + } + + __device__ ~tensor128() { + // this should ~always be optimised away by the compiler + if (!wrote_absmax && scaling && wrote_data) { + //printf("id: %d\n", id); + assert(false); + } + } +}; + +template +__device__ tensor128 new_tensor128(TensorGPU tensor, bool disable_scaling=false) { + if constexpr (init) { + return tensor128(tensor, disable_scaling); + } else { + return tensor128(); + } +} + +template +__device__ tensor128 load_tensor128(TensorGPU tensor, size_t offset, + bool cache_streaming = false, bool disable_scaling=false) { + tensor128 t128(tensor, disable_scaling); + t128.load(offset, cache_streaming); + return t128; +} + +#endif // TENSOR_CUH diff --git a/llmc/zero.cuh b/llmc/zero.cuh index e6c5b6e7c..37f8c1b1f 100644 --- a/llmc/zero.cuh +++ b/llmc/zero.cuh @@ -594,4 +594,3 @@ float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) { } #endif - diff --git a/profile_gpt2.cu b/profile_gpt2.cu index fa5e528d7..940629af6 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -58,7 +58,6 @@ int main(int argc, char *argv[]) { model.config.num_layers = 1; set_zero_configs(&multi_gpu_config, 0, model.num_parameters); - gpt2_allocate_state(&model, B, T); // do a training step gpt2_forward(&model, x, B, T); gpt2_backward_and_reduce(&model, x, y, 1, 0); diff --git a/train_gpt2.cu b/train_gpt2.cu index 16f801387..80af71867 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,3 +1,6 @@ +#define ENABLE_FP8 // todo - makefile option +bool write_as_floatX = true; // todo - make command line option (and test it properly) + /* GPT-2 Transformer Neural Net training loop. See README.md for usage. */ @@ -37,8 +40,10 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/cuda_common.h" // defines: // Packed128, f128, x128 -// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel +// warpReduceSum, warpReduceMax, blockReduce #include "llmc/cuda_utils.cuh" +// todo - document what tensor.cuh implements +#include "llmc/tensor.cuh" // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace // defines: cublas_compute, cublaslt_handle, cublas_handle #include "llmc/cublas_common.h" @@ -52,9 +57,11 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. 
#ifdef ENABLE_CUDNN // defines: create_cudnn, destroy_cudnn, attention_forward_cudnn, attention_backward_cudnn #include "llmc/cudnn_att.h" +#define CUDNN_ENABLED 1 #else // defines: attention_forward, attention_backward #include "llmc/attention.cuh" +#define CUDNN_ENABLED 0 #endif // defines: fused_classifier #include "llmc/fused_classifier.cuh" @@ -71,19 +78,50 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/zero.cuh" // ---------------------------------------------------------------------------- -// global vars for I/O +// global vars regarding the GPU process and disk I/O +cudaDeviceProp deviceProp; // fills in common_start() +cudaStream_t main_stream; char filename_buffer[512]; +constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // buffer for device <-> disk io // ---------------------------------------------------------------------------- -// global vars containing information about the GPU this process is running on -cudaDeviceProp deviceProp; // fills in common_start() -cudaStream_t main_stream; -// buffer size to use for device <-> disk io -constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; +// global vars for tensors (declared as extern in tensor.cuh to be visible everywhere) +// todo - avoid global variables for this? +TensorSpec tensor_specs[MAX_TENSORS] = {0}; +TensorSpec* tensor_specs_gpu = NULL; +size_t tensors_start[TT::COUNT] = {0}; +size_t tensors_bytes[TT::COUNT] = {0}; +size_t tensors_elements[TT::COUNT] = {0}; +int num_tensor_specs = 0; + +TT current_tensor_type = TT::PARAMETER; +int absmax_history_index = 0; +float* gpu_scale_memory = NULL; +unsigned int* gpu_absmax_memory = NULL; +size_t* gpu_tensor_end_element = NULL; + +tensorX null_tensorX = {0}; // ---------------------------------------------------------------------------- // GPT-2 model definition +typedef struct { + int wte, wpe, lnfw, lnfb; // not per layer + int ln1w, ln1b, qkvw, qkvb, attprojw, attprojb, ln2w, ln2b, fcw, fcb, fcprojw, fcprojb; // per layer +} ParameterTensors; + +typedef struct { + int encoded, lnf, lnf_mean, lnf_rstd, losses, output; // not per layer + int ln1, ln1_mean, ln1_rstd, atty, att, attproj, residual2, ln2, ln2_mean, ln2_rstd, fch, fch_gelu, fcproj, residual3, qkvr; // per layer +} ActivationTensors; + +typedef struct { + int bt4c; // (B, T, 4*C) + int btc; // (B, T, C) + int local_scratch, local_scratch_fp32; // big, see local_scratch_size below + int output_scratch, output_scratch_fp32; // huge, see output_size below +} MultiuseTensors; + typedef struct { int max_seq_len; // max sequence length, e.g. 1024 int vocab_size; // vocab size, e.g. 50257 @@ -93,290 +131,260 @@ typedef struct { int channels; // number of channels, e.g. 
768 } GPT2Config; -// the parameters of the model -constexpr const int NUM_PARAMETER_TENSORS = 16; typedef struct { - floatX* wte; // (V, C) - floatX* wpe; // (maxT, C) - floatX* ln1w; // (L, C) - floatX* ln1b; // (L, C) - floatX* qkvw; // (L, 3*C, C) - floatX* qkvb; // (L, 3*C) - floatX* attprojw; // (L, C, C) - floatX* attprojb; // (L, C) - floatX* ln2w; // (L, C) - floatX* ln2b; // (L, C) - floatX* fcw; // (L, 4*C, C) - floatX* fcb; // (L, 4*C) - floatX* fcprojw; // (L, C, 4*C) - floatX* fcprojb; // (L, C) - floatX* lnfw; // (C) - floatX* lnfb; // (C) -} ParameterTensors; -static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); - -void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { - size_t Vp = config.padded_vocab_size; - size_t C = config.channels; - size_t maxT = config.max_seq_len; - size_t L = config.num_layers; - param_sizes[0] = Vp * C; // wte - param_sizes[1] = maxT * C; // wpe - param_sizes[2] = L * C; // ln1w - param_sizes[3] = L * C; // ln1b - param_sizes[4] = L * (3 * C) * C; // qkvw - param_sizes[5] = L * (3 * C); // qkvb - param_sizes[6] = L * C * C; // attprojw - param_sizes[7] = L * C; // attprojb - param_sizes[8] = L * C; // ln2w - param_sizes[9] = L * C; // ln2b - param_sizes[10] = L * (4 * C) * C; // fcw - param_sizes[11] = L * (4 * C); // fcb - param_sizes[12] = L * C * (4 * C); // fcprojw - param_sizes[13] = L * C; // fcprojb - param_sizes[14] = C; // lnfw - param_sizes[15] = C; // lnfb - - // populate the parameter sizes in bytes (all the same for now, keeping for future use) - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - param_sizeof[i] = sizeof(floatX); - } -} + GPT2Config config; + ParameterTensors params[NUM_TYPES_PARAM]; + ActivationTensors acts; + ActivationTensors acts_grads; + MultiuseTensors multiuse; -// allocate memory for the parameters and point the individual tensors to the right places -void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t *param_sizeof) { - // calculate the total number of parameters and bytes across all tensors - size_t num_parameters_bytes = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - num_parameters_bytes += param_elements[i] * param_sizeof[i]; - } - // malloc all parameters all at once on the device - void* params_memory; - cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters_bytes)); - // assign all the tensors their place in the array - floatX** ptrs[] = { - ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, - ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, - ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb - }; - char* params_memory_iterator = (char*)params_memory; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - *(ptrs[i]) = (floatX*)params_memory_iterator; - params_memory_iterator += param_elements[i] * param_sizeof[i]; - } - return params_memory; -} + size_t num_parameters; + size_t num_parameters_bytes; + char* tensor_memory[TT::COUNT] = {0}; -constexpr int NUM_ACTIVATION_TENSORS = 21; -typedef struct { - floatX* encoded; // (B, T, C) - floatX* ln1; // (L, B, T, C) - float* ln1_mean; // (L, B, T) - float* ln1_rstd; // (L, B, T) - floatX* atty; // (L, B, T, C) - // cuDNN saves only some statistics information -#if ENABLE_CUDNN - float* att; // (L, B, NH, T) -#else - floatX* att; // (L, B, NH, T, T) -#endif + // other run state configuration + int batch_size = 0; // the batch size (B) of current forward pass + int seq_len = 0; // the 
sequence length (T) of current forward pass + int* inputs = NULL; // the input tokens for the current forward pass + int* targets = NULL; // the target tokens for the current forward pass + float mean_loss = -1.0f; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps + float* accumulated_mean_loss = NULL; // GPU buffer used to accumulate loss across micro-steps + float* cpu_losses = NULL; // CPU buffer to copy the losses to, allocated with cudaMallocHost + int use_master_weights = 1; // keep master weights copy in float for optim update? 0|1 + int gelu_fusion = 0; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) + int recompute = 0; // recompute gelu | layernorm forward during model backward? 0|1|2 + // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? + int* workload_indices = NULL; // encoder_backward, B*T*num_c_groups (int) + int4* bucket_info = NULL; // encoder_backward, B*T*num_c_groups (int4) - size for worst case - floatX* residual2; // (L, B, T, C) - floatX* ln2; // (L, B, T, C) - float* ln2_mean; // (L, B, T) - float* ln2_rstd; // (L, B, T) - floatX* fch; // (L, B, T, 4*C) - floatX* fch_gelu; // (L, B, T, 4*C) - floatX* residual3; // (L, B, T, C) - floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms - float* lnf_mean; // (B, T) - float* lnf_rstd; // (B, T) - float* losses; // (B, T), will be accumulated in micro-steps - // adding these two compared to the CPU .c code, needed for attention kernel as buffers - floatX* qkvr; // (L, B, T, 3*C) - // in inference mode, this buffer will store the logits - // in training mode, this buffer will contain the *gradients* of the logits. - // during the processing of transformer blocks, we will also use this as a - // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C), - // (B, NH, T, T), and (B, T, V) shaped tensors. - floatX* output; - - // some additional scratch buffers - floatX* scratch_bt4c; // (B, T, 4*C) - floatX* scratch_btc; // (B, T, C) -} ActivationTensors; + unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. 
+ unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights +} GPT2; +#define TENSOR_SPECS(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype, -1, flags, reuse_every_n) +#define TENSOR_SPECS_LOWP(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype_lowp, -1, flags, reuse_every_n) +#define TENSOR_SPECS_FP32(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, DType::FP32, -1, flags, reuse_every_n) -struct TensorSpec { - void** ptr; - size_t size; - DType type; -}; +void gpt2_allocate(GPT2 *model) { + size_t Vp = model->config.padded_vocab_size; + size_t C = model->config.channels; + size_t maxT = model->config.max_seq_len; + size_t L = model->config.num_layers; + size_t B = model->batch_size; + size_t T = model->seq_len; + size_t NH = model->config.num_heads; + size_t BTC = B*T*C; + size_t BT = B*T; + + // output is also used as a scratch buffer (floatX), needs to be big enough for: + // 1) Output: B*T*Vp (padded vocabulary size) + // 2) 4BTC (largest activation/grad tensor) + // 3) 4CC FP8 (largest parameter tensor, 2*C*C if floatX=BF16) + // 4) B*T*T*NH (non-cuDNN attention tensor) + size_t output_size = max(B*T * max(Vp, 4*C), 4*C*C/sizeof(floatX)); + output_size = CUDNN_ENABLED ? output_size : max(output_size, B*T*T*NH); + // local scratch (floatX), must be big enough for: + // 1) BTC (in floatX) + // 2) 4BTC FP8 (transpose cache) + // 2) 4CC FP8 (largest parameter tensor in FP8) + // 3) 4BTC BF16 (non-cuDNN backwards scratch in floatX) + size_t local_scratch_size = max(CUDNN_ENABLED ? 4*BTC/sizeof(floatX) : 4*BTC, 4*C*C/sizeof(floatX)); + + int reuse_every_n = 0; + int shards = 1; + int shards_opt = (multi_gpu_config.zero_stage >= 1) ? multi_gpu_config.num_processes : 1; + int shards_grad = (multi_gpu_config.zero_stage >= 2) ? multi_gpu_config.num_processes : 1; + + // 1) parameters & optimizer state + for (int t = PARAMETER; t <= PARAMETER_MASTER; t++) { + if (t == PARAMETER_MASTER && !model->use_master_weights) continue; + + current_tensor_type = (TT)t; + ParameterTensors* spec = &model->params[t]; + shards = (t == PARAMETER) ? 1 : (t == PARAMETER_GRAD) ? shards_grad : shards_opt; + + DType dtype = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; + DType dtype_lowp = (t == PARAMETER) ? DTYPE_FP8E4 : ((t == PARAMETER_GRAD) ? 
DTYPE_FLOATX : DType::FP32); + + TENSOR_SPECS (wte, 1, Vp * C, TENSOR_2D | EMBEDDING); + TENSOR_SPECS (wpe, 1, maxT * C, TENSOR_2D | EMBEDDING); + TENSOR_SPECS (ln1w, L, C, LAYERNORM); + TENSOR_SPECS (ln1b, L, C, LAYERNORM | BIAS); + TENSOR_SPECS_LOWP(qkvw, L, 3 * C * C, TENSOR_2D); + TENSOR_SPECS (qkvb, L, 3 * C, BIAS); + TENSOR_SPECS (attprojw, L, C * C, TENSOR_2D); + TENSOR_SPECS (attprojb, L, C, BIAS); + TENSOR_SPECS (ln2w, L, C, LAYERNORM); + TENSOR_SPECS (ln2b, L, C, LAYERNORM | BIAS); + TENSOR_SPECS_LOWP(fcw, L, 4 * C * C, TENSOR_2D); + TENSOR_SPECS (fcb, L, 4 * C, BIAS); + TENSOR_SPECS_LOWP(fcprojw, L, 4 * C * C, TENSOR_2D); + TENSOR_SPECS (fcprojb, L, C, BIAS); + TENSOR_SPECS (lnfw, 1, C, LAYERNORM); + TENSOR_SPECS (lnfb, 1, C, LAYERNORM | BIAS); + } + model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; + model->num_parameters = tensors_elements[TT::PARAMETER]; -#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}; + // 2) multiuse & scratch tensors + current_tensor_type = MULTIUSE; + model->multiuse.bt4c = model->multiuse.btc = -1; + if (UNIQUE_TENSOR_MEMORY == false) { + model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + } + model->multiuse.local_scratch = add_tensor_spec("scratch_X", local_scratch_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.output_scratch = add_tensor_spec("out_scratch_X", output_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + + model->multiuse.local_scratch_fp32 = add_tensor_spec("scratch_32", local_scratch_size / sizeof(floatX), 1, DType::FP32, model->multiuse.local_scratch, REUSED_MEMORY); + model->multiuse.output_scratch_fp32 = add_tensor_spec("out_scratch_32", output_size / sizeof(floatX), 1, DType::FP32, model->multiuse.output_scratch, REUSED_MEMORY); + + // 3) activations + ActivationTensors* spec = &model->acts; + DType dtype_lowp = DTYPE_FP8E4; + DType dtype = DTYPE_FLOATX; + shards = 1; + + // with activation checkpointing, we keep every layer's residual3 for simplicity + // in theory, with e.g. 4 layers per checkpoint, we'd have 1/4 as many residual3 + // but that would complicate everything a lot for relatively little benefit... + TENSOR_SPECS (residual3, L, BTC, RESIDUAL); + reuse_every_n = LAYERS_PER_ACTIVATION_CHECKPOINT; + assert(!reuse_every_n || !(L % reuse_every_n)); + + TENSOR_SPECS (encoded, 1, BTC, EMBEDDING); + TENSOR_SPECS (qkvr, L, 3 * BTC, TENSOR_2D); + TENSOR_SPECS (atty, L, BTC, 0); + TENSOR_SPECS (residual2, L, BTC, RESIDUAL); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, TENSOR_2D); -void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) { - size_t Vp = config.padded_vocab_size; - size_t L = config.num_layers; - size_t NH = config.num_heads; - size_t C = config.channels; - tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); - // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass - tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? 
L * B * T * C : 0); - tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T); - tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T); - tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C); #ifdef ENABLE_CUDNN - // FP32 stats tensor for cuDNN to be passed to backward pass - tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T); + TENSOR_SPECS_FP32(att, L, NH * B * T, 0); #else - tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T); + TENSOR_SPECS (att, L, NH * B * T * T, 0); #endif - tensors[6] = TENSOR_SPEC(data->residual2, L * B * T * C); - // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass - tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0); - tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T); - tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T); - tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C); - // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer - tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C); - tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C); - tensors[13] = TENSOR_SPEC(data->lnf, B * T * C); - tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T); - tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T); - tensors[16] = TENSOR_SPEC(data->losses, B * T); - tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); - tensors[18] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp))); - - tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); - tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); -} -void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) { - size_t bytes = 0; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - bytes += tensors[i].size * sizeof_dtype(tensors[i].type); + // optionally reuse the same activation buffer at each layer and re-compute the gelu during backward + // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size + if (model->recompute < 1 || UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS_LOWP(fch_gelu, L, 4 * BTC, 0); + TENSOR_SPECS_LOWP(ln1, L, BTC, LAYERNORM); + TENSOR_SPECS_LOWP(ln2, L, BTC, LAYERNORM); + } else if (model->recompute < 2) { + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->multiuse.output_scratch, REUSED_MEMORY); + TENSOR_SPECS_LOWP(ln1, L, BTC, LAYERNORM); + TENSOR_SPECS_LOWP(ln2, L, BTC, LAYERNORM); + } else { + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->multiuse.output_scratch, REUSED_MEMORY); + spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype_lowp, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); + spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype_lowp, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); } - printf0("allocating %d MiB for activations\n", (int)round(bytes / (1024 * 1024))); - - void* acts_memory; - cudaCheck(cudaMalloc((void**)&acts_memory, bytes)); - - // cudaMalloc does not guarantee initial memory values so we memset the allocation here - // this matters because e.g. 
non-cuDNN attention assumes the attention buffer is zeroed - // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer - cudaCheck(cudaMemset(acts_memory, 0, bytes)); - - char* acts_memory_iterator = (char*)acts_memory; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - // extra protection so we don't accidentally use an empty buffer - if(tensors[i].size == 0) { - *(tensors[i].ptr) = NULL; - }else { - *(tensors[i].ptr) = acts_memory_iterator; - acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type); - } + if (UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS (attproj, L, BTC, TENSOR_2D); + TENSOR_SPECS_LOWP(fcproj, L, BTC, TENSOR_2D); + TENSOR_SPECS (output, 1, output_size, TENSOR_2D | EMBEDDING); + } else { + spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); + spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); + spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch, REUSED_MEMORY | EMBEDDING | TENSOR_2D); } - return acts_memory; -} - -typedef struct { - GPT2Config config; - // the weights of the model, and their sizes - ParameterTensors params; - size_t param_elements[NUM_PARAMETER_TENSORS]; - size_t param_sizeof[NUM_PARAMETER_TENSORS]; - void* params_memory; - size_t num_parameters; - size_t num_parameters_bytes; - // gradients of the weights - ParameterTensors grads; - void* grads_memory; - // buffers for the AdamW optimizer - float* m_memory; - float* v_memory; - float* master_weights; // is NULL unless fp32 weights is enabled. - // the activations of the model, and their sizes - ActivationTensors acts; - TensorSpec acts_specs[NUM_ACTIVATION_TENSORS]; - void* acts_memory; - // other run state configuration - int batch_size; // the batch size (B) of current forward pass - int seq_len; // the sequence length (T) of current forward pass - int* inputs; // the input tokens for the current forward pass - int* targets; // the target tokens for the current forward pass - float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps - float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps - float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost - unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. - unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights - int use_master_weights; // keep master weights copy in float for optim update? 0|1 - bool init_state; // set to true if master weights need to be initialized - int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) - int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2 - // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? 
- int* workload_indices; // encoder_backward, B*T*num_c_groups (int) - int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case -} GPT2; - -void gpt2_init_common(GPT2 *model) { - // common inits outside of the model weights - // memory lazily initialized in forward() - model->acts_memory = NULL; - model->inputs = NULL; - model->targets = NULL; - model->accumulated_mean_loss = NULL; - model->cpu_losses = NULL; - // the B,T params are determined and set, fixed on first batch in forward() - model->batch_size = 0; - model->seq_len = 0; - model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward() - model->params_memory = NULL; - // memory lazily initialized in backward() - model->grads_memory = NULL; - model->workload_indices = NULL; // on cpu, for encoder_backward - model->bucket_info = NULL; // on cpu, for encoder_backward - // memory lazily initialized in update() - model->m_memory = NULL; - model->v_memory = NULL; - model->master_weights = NULL; - // other default settings - model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding - model->use_master_weights = 1; // safe default: do keep master weights in fp32 - model->init_state = true; - model->recompute = 1; // good default: recompute gelu but not layernorm - model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main()) -} -void gpt2_allocate_weights(GPT2 *model) { - // fill in all the parameter tensor dimensions and types - fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config); - model->num_parameters = 0; - model->num_parameters_bytes = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - model->num_parameters += model->param_elements[i]; - model->num_parameters_bytes += model->param_elements[i] * model->param_sizeof[i]; + TENSOR_SPECS (lnf, 1, BTC, LAYERNORM); + TENSOR_SPECS_FP32(losses, 1, BT, 0); + TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(lnf_mean, 1, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(lnf_rstd, 1, BT, LAYERNORM | STATS); + + // 4) activation gradients + // note: TENSOR_2D are for the tensors written to by a matmul which are different here + // todo - is "LAYERNORM" applied logically here? do we care? 
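    // [editor's note] for readers following the macros: in this section dtype_lowp is set to DTYPE_FP8E5 just below,
    // so e.g. TENSOR_SPECS_LOWP(fch, L, 4 * BTC, GRADIENT) (used in the UNIQUE_TENSOR_MEMORY branch) expands to
    // spec->fch = add_layer_specs(L, "fch", 4 * BTC, shards, dtype_lowp, -1, GRADIENT, reuse_every_n),
    // i.e. one spec per layer, freshly allocated (-1 = no aliasing), in FP8 e5m2.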
+ reuse_every_n = 0; + spec = &model->acts_grads; + dtype_lowp = DTYPE_FP8E5; + shards = 1; + + if (UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS (encoded, 1, BTC, GRADIENT | EMBEDDING); + TENSOR_SPECS (output, 1, output_size, GRADIENT | EMBEDDING); + TENSOR_SPECS (lnf, 1, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS_LOWP(ln1, L, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS (atty, L, BTC, GRADIENT); + TENSOR_SPECS (residual2, L, BTC, GRADIENT | RESIDUAL); + TENSOR_SPECS_LOWP(ln2, L, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, GRADIENT); + TENSOR_SPECS_LOWP(fch_gelu, L, 4 * BTC, GRADIENT); + TENSOR_SPECS (residual3, L, BTC, GRADIENT | RESIDUAL); + TENSOR_SPECS (qkvr, L, 3 * BTC, GRADIENT); + } else { + spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch, GRADIENT | EMBEDDING); + + int reused_btc = model->acts.residual3 + (L-1); // todo - check if this works with activation checkpointing + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM); + spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT); + + int reused_btc2 = model->acts.lnf; + spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); + spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); + spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT | EMBEDDING); + + // (lnf doesn't need bt4c but it's free at this point unlike the other buffers) + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT); + spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT); + spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); } - // create memory for model parameters on the device - assert(model->params_memory == nullptr); - model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); -} -void gpt2_allocate_state(GPT2 *model, int B, int T) { - printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024))); - assert(model->grads_memory == nullptr); - model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof); + // allocate a single huge GPU buffer for all the tensors of a given type + for (int i = 0; i < TT::COUNT; i++) { + if (i == PARAMETER_MASTER && !model->use_master_weights) continue; + cudaCheck(cudaMalloc(&model->tensor_memory[i], tensors_bytes[i])); + } - // record the current B,T as well - model->batch_size = B; - model->seq_len = T; + // Set the GPU pointer for each tensor spec (so we don't need to know the base and the offset) + // also specify the end elements explicitly to optimise kernels iterating over the tensors + size_t* cpu_tensor_end_element = (size_t*)mallocCheck(sizeof(size_t) * num_tensor_specs + 256); + for (size_t i = 0; i < num_tensor_specs; i++) { + TensorSpec* spec = &tensor_specs[i]; + spec->ptr = model->tensor_memory[spec->tensor_type] + spec->offset; + cpu_tensor_end_element[i] = spec->start_element + spec->num_elements; + } - // allocate the space - fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, 
model->config, model->recompute); - model->acts_memory = malloc_and_point_activations(model->acts_specs); - // also create memory for caching inputs and targets + // we are finished creating the tensors specs and copy them to the GPU (they are effectively read-only) + cudaMalloc((void**)&tensor_specs_gpu, sizeof(TensorSpec) * num_tensor_specs); + cudaMemcpy(tensor_specs_gpu, tensor_specs, sizeof(TensorSpec) * num_tensor_specs, cudaMemcpyHostToDevice); + // also upload the "end element" array which we use to optimise iterating through tensors in our kernels + // extra 256B so that we can avoid bounds checking when prefetching etc. + cudaMalloc(&gpu_tensor_end_element, sizeof(size_t) * num_tensor_specs + 256); + cudaMemcpy(gpu_tensor_end_element, cpu_tensor_end_element, sizeof(size_t) * num_tensor_specs + 256, cudaMemcpyHostToDevice); + free(cpu_tensor_end_element); + + // todo - move this elsewhere so it's not in the middle of the parameter table... + printf("number of parameter bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER] / (1024*1024)); + printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRAD] / (1024*1024)); + printf("number of m bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_M] / (1024*1024)); + printf("number of v bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_V] / (1024*1024)); + printf("number of master weight bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_MASTER] / (1024*1024)); + printf("number of act+actgrad+multiuse bytes: %zu MiB\n", tensors_bytes[TT::MULTIUSE] / (1024*1024)); + + // absmax/scale/descale buffers for FP8 & Friends (scale is initialised via update_scales_from_absmax) + cudaMalloc(&gpu_scale_memory, 2 * num_tensor_specs * sizeof(float)); + cudaMalloc(&gpu_absmax_memory, sizeof(unsigned int) * num_tensor_specs * (MAX_ABSMAX_HISTORY + 1)); + cudaMemset(gpu_absmax_memory, 0, sizeof(unsigned int) * num_tensor_specs * (MAX_ABSMAX_HISTORY + 1)); + + // copy pointers to constant buffers for easy access on the GPU + cudaMemcpyToSymbol(tensor_specs_ptr, &tensor_specs_gpu, sizeof(TensorSpec*)); + cudaMemcpyToSymbol(gpu_scale_memory_ptr, &gpu_scale_memory, sizeof(float*)); + cudaMemcpyToSymbol(gpu_absmax_memory_ptr, &gpu_absmax_memory, sizeof(unsigned int*)); + cudaMemcpyToSymbol(tensor_end_element_ptr, &gpu_tensor_end_element, sizeof(size_t*)); + // ======================= + // allocate_state stuff + // ======================= cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float))); @@ -388,31 +396,103 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups); model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups); - size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for - printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); - printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); - assert(model->m_memory == nullptr); - assert(model->v_memory == nullptr); - cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float))); - cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); - - if 
(model->use_master_weights == 1) { - assert(model->master_weights == nullptr); - printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); - cudaCheck(cudaMalloc((void**) &model->master_weights, shard_num_parameters * sizeof(float))); - } - + // check available memory and give an estimate of the maximum batch size size_t free, total; cudaCheck(cudaMemGetInfo(&free, &total)); printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); - // give an estimate of the maximum batch size - size_t bytes_per_sequence = 0; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - bytes_per_sequence += model->acts_specs[i].size * sizeof_dtype(model->acts_specs[i].type) / B; - } + size_t bytes_per_sequence = tensors_bytes[TT::MULTIUSE] / B; // pessimistic (output buffer etc.) printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); + cudaCheck(cudaGetLastError()); +} + +void gpt2_init_common(GPT2 *model) { + // other default settings + model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding +} + +// take a GPU buffer with "num_parameters * sizeof(floatX)" in the required order +// and convert each individual tensor to its desired data type +template +void convert_fixed_parameters(GPT2* model, char* gpu_buffer, size_t fixed_size_bytes) { + size_t offset = 0; + int num_param_tensors = tensors_start[PARAMETER+1]; + + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU tensor = tensor_specs[i]; + tensor.data_ptr = (Tin*)(gpu_buffer + offset); + offset += tensor.num_elements * sizeof(Tin); + update_absmax(tensor); + } + update_scales_from_absmax(); + + offset = 0; + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU tensor_in = tensor_specs[i]; + tensor_in.data_ptr = (Tin*)(gpu_buffer + offset); + offset += tensor_in.num_elements * sizeof(Tin); + + switch (tensor_specs[i].data_type) { + case DType::FP32: copy_advanced((TensorGPU)tensor_specs[i], tensor_in); break; + case DType::FP16: copy_advanced((TensorGPU<__nv_half>)tensor_specs[i], tensor_in); break; + case DType::BF16: copy_advanced((TensorGPU<__nv_bfloat16>)tensor_specs[i], tensor_in); break; + case DType::FP8E4M3: copy_advanced((TensorGPU<__nv_fp8_e4m3>)tensor_specs[i], tensor_in); break; + } + if (model->use_master_weights) { + size_t master_start = tensors_start[PARAMETER_MASTER]; + TensorGPU master = tensor_specs[i+master_start]; + size_t shard_offset = master.num_elements * (multi_gpu_config.process_rank % tensor_specs[i+master_start].num_shards); + + tensor_in.data_ptr += shard_offset; + copy_advanced(master, tensor_in); + } + } + cudaMemset(gpu_buffer, 0, fixed_size_bytes); +} + +// convert from variable precision parameters to a single precision (e.g. 
before checkpointing) +// todo +template +void convert_to_fixed_parameters(GPT2* model, char* gpu_buffer) { + size_t offset = 0; + for (int i = 0; i < tensors_start[PARAMETER+1]; i++) { + TensorGPU tensor_out = tensor_specs[i]; + tensor_out.data_ptr = (Tout*)(gpu_buffer + offset); + offset += tensor_out.num_elements * sizeof(Tout); + + switch (tensor_specs[i].data_type) { + case DType::FP32: copy_advanced(tensor_out, (TensorGPU)tensor_specs[i]); break; + case DType::FP16: copy_advanced(tensor_out, (TensorGPU<__nv_half>)tensor_specs[i]); break; + case DType::BF16: copy_advanced(tensor_out, (TensorGPU<__nv_bfloat16>)tensor_specs[i]); break; + case DType::FP8E4M3: copy_advanced(tensor_out, (TensorGPU<__nv_fp8_e4m3>)tensor_specs[i]); break; + } + } +} + +// helper function to initialise sharded master weights from unsharded weights +template +void init_tensor_shard(TensorGPU out, int i) { + size_t shard_offset = out.num_elements * (multi_gpu_config.process_rank % tensor_specs[out.id].num_shards); + TensorGPU t = tensor_specs[i]; + t.num_elements = out.num_elements; + t.data_ptr += shard_offset; + copy_advanced(out, t); +} + +// initialise master weights based on the regular weights, taking into account sharding +void init_master_weights(GPT2 *model) { + int num_param_tensors = tensors_start[PARAMETER+1]; + int master_start = tensors_start[PARAMETER_MASTER]; // relies on there being the same number of parameter and master parameter tensors + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU master = tensor_specs[i+master_start]; + switch (tensor_specs[i].data_type) { + case DType::FP32: init_tensor_shard(master, i); break; + case DType::FP16: init_tensor_shard<__nv_half>(master, i); break; + case DType::BF16: init_tensor_shard<__nv_bfloat16>(master, i); break; + case DType::FP8E4M3: init_tensor_shard<__nv_fp8_e4m3>(master, i); break; + } + } } void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { @@ -433,13 +513,20 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - device_to_file(model_file, model->params_memory, model->num_parameters_bytes, - IO_BUF_SIZE, main_stream); + if (write_as_floatX && model->num_parameters_bytes != model->num_parameters * sizeof(floatX)) { + // convert the parameters to floatX before writing them + assert(tensors_bytes[MULTIUSE] >= model->num_parameters * sizeof(floatX)); // todo - make this always work + convert_to_fixed_parameters(model, model->tensor_memory[MULTIUSE]); + device_to_file(model_file, model->tensor_memory[MULTIUSE], model->num_parameters * sizeof(floatX), IO_BUF_SIZE); + } else { + // just write the parameters as they are + device_to_file(model_file, model->tensor_memory[PARAMETER], model->num_parameters_bytes, IO_BUF_SIZE); + } // close file, we're done fcloseCheck(model_file); } -void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool weight_init=true) { +void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { // If weight_init is true, we will load the weights from this checkpoint .bin file // We sometimes want this to be false, if we are going to initialize these weights from // the master weights that are instead stored in the state .bin file. 
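[Editor's note] A small worked example of the sharding arithmetic used by init_tensor_shard() and init_master_weights() above; the concrete sizes (4096 elements, 8 processes, rank 3) are made up for illustration. Under ZeRO-1, master and optimizer tensors are created with num_shards equal to the number of processes, so their num_elements field is already the per-shard count and each rank copies only its own slice of the unsharded parameter:

// hypothetical numbers, assuming ZeRO-1 across 8 processes
size_t total_elements = 4096;
size_t num_shards     = 8;                                // == multi_gpu_config.num_processes
size_t per_shard      = total_elements / num_shards;      // 512, stored as the master tensor's num_elements
int    rank           = 3;                                // == multi_gpu_config.process_rank
size_t shard_offset   = per_shard * (rank % num_shards);  // 1536, so rank 3 copies source elements [1536, 2048)

Without ZeRO (num_shards == 1), rank % num_shards is always 0 and every process copies the full tensor.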
@@ -467,19 +554,18 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w exit(EXIT_FAILURE); } - // check if the precision mode of the checkpoing matches the model precision - if (weight_init) { - if (PRECISION_MODE == PRECISION_BF16 && version != 5) { - fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path); - fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n"); - exit(EXIT_FAILURE); - } - if (PRECISION_MODE == PRECISION_FP32 && version != 3) { - fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path); - fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n"); - fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n"); - exit(EXIT_FAILURE); - } + // check if the precision mode of the checkpoint matches the model precision + // todo - we could support this (and FP16) fairly easily by modifying convert_fixed_parameters() a bit... + if (PRECISION_MODE == PRECISION_BF16 && version != 5) { + fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path); + fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n"); + exit(EXIT_FAILURE); + } + if (PRECISION_MODE == PRECISION_FP32 && version != 3) { + fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path); + fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n"); + fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n"); + exit(EXIT_FAILURE); } // read in hyperparameters @@ -490,17 +576,26 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w model->config.channels = model_header[6]; model->config.padded_vocab_size = model_header[7]; - // allocate memory for the model parameters - gpt2_allocate_weights(model); + // key line to allocate all of the GPU buffers for all of the tensors + gpt2_allocate(model); - // read in the parameters if weight_init is true - if (weight_init) { - assert(model->params_memory != NULL); - file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); + // if the number of bytes in the checkpoint doesn't match the number of bytes allocated, + // we assume the checkpoint is all floatX but our model has different sizes for different tensors (e.g. 
FP8) + fseek(model_file, 0, SEEK_END); + size_t checkpoint_bytes = ftell(model_file) - sizeof(model_header); + fseek(model_file, sizeof(model_header), SEEK_SET); + + if (checkpoint_bytes != model->num_parameters_bytes) { + assert(checkpoint_bytes == model->num_parameters * sizeof(floatX)); + assert(checkpoint_bytes <= tensors_bytes[MULTIUSE]); // todo - won't work if params size > activations size + file_to_device(model->tensor_memory[MULTIUSE], model_file, checkpoint_bytes, IO_BUF_SIZE); + convert_fixed_parameters(model, model->tensor_memory[MULTIUSE], checkpoint_bytes); + } else { + file_to_device(model->tensor_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); } - fcloseCheck(model_file); // only return from this function once we are certain the params are ready on the GPU + fcloseCheck(model_file); cudaCheck(cudaDeviceSynchronize()); } @@ -570,176 +665,123 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { model->config.vocab_size = 50257; model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel efficiency - gpt2_allocate_weights(model); + gpt2_allocate(model); // allocate and random init the memory for all the parameters with GPT-2 schema // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5) // NOTE: assuming all parameters are of the type floatX, could be relaxed later mt19937_state init_rng; manual_seed(&init_rng, 42); - floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes); - memset(params_memory_cpu, 0, model->num_parameters_bytes); + size_t fixed_size_bytes = model->num_parameters * sizeof(floatX); + floatX* params_memory_cpu = (floatX*)mallocCheck(fixed_size_bytes); + memset(params_memory_cpu, 0, fixed_size_bytes); + // fill in all the weights with random values - float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); // we have to init all these tensors exactly in the order that PyTorch initializes them // so that we can match them up and get correctness and exactly the same initial conditions - size_t L = model->config.num_layers; + float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); size_t offset = 0; - for (int l = 0; l < L; l++) { - offset = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - // the layernorm parameters are all initialized to 1 - if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once - for (size_t j = 0; j < model->param_elements[i]; j++) { - params_memory_cpu[offset + j] = 1.0f; - } + + int num_param_tensors = tensors_start[PARAMETER+1]; + for (int i = 0; i < num_param_tensors; i++) { + TensorSpec tensor = tensor_specs[i]; + if ((tensor.tensor_flags & TFlags::LAYERNORM) && !(tensor.tensor_flags & BIAS)) { + for (size_t j = 0; j < tensor.num_elements; j++) { + params_memory_cpu[offset + j] = (floatX)1.0f; } - // weights tensors are handled here - if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors - || i == 4 || i == 6 || i == 10 || i == 12) { - size_t n = model->param_elements[i]; - size_t layer_offset = 0; - if (i == 0) { - // for wte tensor (padded vocab) override to init V instead of Vp rows - n = model->config.vocab_size * model->config.channels; - } - if (i == 4 || i == 6 || i == 10 || i == 12) { - // weight tensors, we are only initializing layer l - assert(n % L == 0); - n = n / L; - layer_offset = l * n; - } - // in GPT-2, the projections back into the residual stream are additionally - // scaled by 1/sqrt(2*L) for training stability - float scale 
= (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f; - // okay let's draw the random numbers and write them - float *fp32_buffer = (float*)mallocCheck(n * sizeof(float)); - normal_(fp32_buffer, n, 0.0f, scale, &init_rng); - for (size_t j = 0; j < n; j++) { - params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j]; - } - free(fp32_buffer); + } + if (tensor.tensor_flags & TENSOR_2D) { + size_t n = tensor.num_elements; + if (n == model->config.padded_vocab_size * model->config.channels) { + n = model->config.vocab_size * model->config.channels; + } + + // in GPT-2, the projections back into the residual stream are additionally + // scaled by 1/sqrt(2*L) for training stability + float scale = 0.02f; + if (strstr(tensor.name, "proj") != NULL) { // todo: yuck - use TFlags! + scale *= residual_scale; } - offset += model->param_elements[i]; + + float *fp32_buffer = (float*)mallocCheck(n * sizeof(float)); + normal_(fp32_buffer, n, 0.0f, scale, &init_rng); + for (size_t j = 0; j < n; j++) { + params_memory_cpu[offset + j] = (floatX)fp32_buffer[j]; + } + free(fp32_buffer); } + offset += tensor.num_elements; + } + + // if the actual allocation doesn't match "params * sizeof(floatX)" we need to convert everything, otherwise just copy. + if (fixed_size_bytes != model->num_parameters_bytes) { + assert(tensors_bytes[MULTIUSE] >= model->num_parameters * sizeof(floatX)); // todo - make this always work + cudaMemcpy(model->tensor_memory[MULTIUSE], params_memory_cpu, fixed_size_bytes, cudaMemcpyHostToDevice); + convert_fixed_parameters(model, model->tensor_memory[MULTIUSE], fixed_size_bytes); + } else { + cudaCheck(cudaMemcpy(model->tensor_memory[PARAMETER], params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); } - // copy them to GPU - cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); free(params_memory_cpu); } // propagate inputs through the network to produce logits. -// right now, this function is fully synchronous with the host void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); - // we must be careful and use size_t instead of int, otherwise - // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16. 
- - // ensure the model was initialized or error out - if (model->params_memory == NULL) { - printf("Error: model was not initialized properly.\n"); - exit(EXIT_FAILURE); - } - - // convenience parameters const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; - // validate B,T are not larger than the values used at initialisation - // (smaller B,T are okay for inference only) + // validate B,T are not larger than the values at initialisation (smaller is OK for inference) if (B > model->batch_size || T > model->seq_len) { printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T); exit(EXIT_FAILURE); } - - // copy inputs/targets to the model - cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); + // unused parts of attention buffer must be zeroed for non-cuDNN path + if (!CUDNN_ENABLED && T != model->seq_len) { + cudaCheck(cudaMemset(ACT_0(att), 0, L * B * NH * T * T * sizeof(floatX))); + } // validate inputs, all indices must be in the range [0, V) - // we can do this while the copies are already underway tokenCheck(inputs, B*T, V); - // forward pass - ParameterTensors params = model->params; // for brevity - ActivationTensors acts = model->acts; - encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0] - - // first layernorm isn't fused - layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream); + // copy inputs/targets to the model (fully synchronous with the host for now) + cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); + // start of forward pass with encoder (layer 0) + int l = 0; + encoder_forward(ACT(encoded), model->inputs, PARAM(wte), PARAM(wpe), B, T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), ACT(encoded), PARAM(ln1w), PARAM(ln1b), B*T, C); - for (int l = 0; l < L; l++) { + for (; l < L; l++) { NvtxRange layer_range("Layer", l); + tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + + tensorX qkvr = ACT(qkvr); // non-cudnn reuses tensor with different memory pre/post-permute + qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); + matmul_forward(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); - floatX* residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; - - // get the pointers of the weights for this layer - floatX* l_qkvw = params.qkvw + l * 3*C * C; - floatX* l_qkvb = params.qkvb + l * 3*C; - floatX* l_attprojw = params.attprojw + l * C * C; - floatX* l_attprojb = params.attprojb + l * C; - floatX* l_ln2w = params.ln2w + l * C; - floatX* l_ln2b = params.ln2b + l * C; - floatX* l_fcw = params.fcw + l * 4*C * C; - floatX* l_fcb = params.fcb + l * 4*C; - floatX* l_fcprojw = params.fcprojw + l * C * 4*C; - floatX* l_fcprojb = params.fcprojb + l * C; - - // get the pointers of the activations for this layer - floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; - floatX* l_atty = acts.atty + l * B * T * C; - floatX* l_residual2 = acts.residual2 + l * B * T * C; - floatX* l_ln2 = (model->recompute < 2) ? 
acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; - float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - floatX* l_fch = acts.fch + l * B * T * 4*C; - // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward - // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size - floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; - floatX* l_residual3 = acts.residual3 + l * B * T * C; - floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. - - // now do the forward pass #ifdef ENABLE_CUDNN - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); - attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); + attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C, main_stream); #else - floatX* l_att = acts.att + l * B * NH * T * T; - if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) - cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); - } - // these are only needed as scratchpads for the forward pass, but - // need not be stored for backward - matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); - attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); + attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif - matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); - fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream); - matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); - matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); - // OK, fusion across blocks. + matmul_forward(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + if(l+1 != L) { - floatX* l_ln1 = (model->recompute < 2) ? 
acts.ln1 + (l + 1) * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T; - float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T; - const floatX* l_ln1w = params.ln1w + (l + 1) * C; - const floatX* l_ln1b = params.ln1b + (l + 1) * C; - fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b, - B * T, C, main_stream); + fused_residual_forward5(ACT(residual3), ACT_L(ln1, l+1), ACT_L(ln1_mean, l+1), ACT_L(ln1_rstd, l+1), ACT(residual2), ACT(fcproj), + PARAM_L(ln1w, l+1), PARAM_L(ln1b, l+1), B*T, C); } else { - fused_residual_forward5(l_residual3, acts.lnf, acts.lnf_mean, acts.lnf_rstd, l_residual2, scratch, - params.lnfw, params.lnfb, - B * T, C, main_stream); + fused_residual_forward5(ACT(residual3), ACT(lnf), ACT(lnf_mean), ACT(lnf_rstd), ACT(residual2), ACT(fcproj), + PARAM(lnfw), PARAM(lnfb), B*T, C); } } - matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); - cudaCheck(cudaDeviceSynchronize()); + matmul_forward(ACT(output), ACT(lnf), PARAM(wte), null_tensorX, B*T, C, Vp); } @@ -755,16 +797,15 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B const size_t Vp = model->config.padded_vocab_size; NvtxRange classifier_and_loss_range("classifier_and_loss"); - ActivationTensors acts = model->acts; float mean_loss = 0.0f; // fused classifier: does the forward pass and first part of the backward pass const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements // note: we don't need to generate dlogits here - cudaCheck(cudaMemset(acts.losses, 0, B*T*sizeof(float))); + cudaCheck(cudaMemset(ACT_0(losses), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream); - cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); + fused_classifier(ACT_0(output), ACT_0(output), ACT_0(losses), dloss, model->targets, B*T, V, Vp, False); + cudaCheck(cudaMemcpy(model->cpu_losses, ACT_0(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; } @@ -774,22 +815,8 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B } void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { - if(model->grads_memory == nullptr) { - fprintf(stderr, "Need to allocate gradients before backward"); - exit(EXIT_FAILURE); - } NVTX_RANGE_FN(); - bool last_step = micro_step == grad_accum_steps - 1; - // on the first micro-step zero the gradients, as we're about to += accumulate into them - if (micro_step == 0) { - // there are currently two state vars during the gradient accumulation inner loop: - // 1) the losses accumulate += into acts.losses, reset here - // 2) the gradients accumulate += into grads_memory, reset here - cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream)); - cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream)); - } - - // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow + // convenience shortcuts (size_t instead of int so that pointer arithmetics don't 
overflow) const size_t B = model->batch_size; const size_t T = model->seq_len; const size_t V = model->config.vocab_size; @@ -797,131 +824,82 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; + int l = L-1; // start from the last layer - ParameterTensors params = model->params; // for brevity - ParameterTensors grads = model->grads; - ActivationTensors acts = model->acts; + bool last_step = micro_step == grad_accum_steps - 1; + // on the first micro-step zero the gradients, as we're about to += accumulate into them + if (micro_step == 0) { + // there are currently two state vars during the gradient accumulation inner loop: + // 1) the losses accumulate += into acts.losses, reset here + // 2) the gradients accumulate += into grads_memory, reset here + cudaCheck(cudaMemsetAsync(ACT(losses), 0, B * T * sizeof(float), main_stream)); + cudaCheck(cudaMemsetAsync(model->tensor_memory[PARAMETER_GRAD], 0, tensors_bytes[PARAMETER_GRAD], main_stream)); + } // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier NvtxRange classifier_and_loss_range("classifier_and_loss"); const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); - fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream); + fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput - // backward pass: go in the reverse order of the forward pass, and call backward() functions - - // reset residual stream gradients (put here to work with gradient accumulation) - floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass - cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX))); - - // re-use the output buffer of the forward pass as a scratchpad during backward pass - float* scratchF = (float*)acts.output; - floatX* scratchX = (floatX*)acts.output; + // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer + tensor32 scratchF_HUGE = MULTI_0(output_scratch_fp32); // Largest buffer imaginable (max of output & everything else) + tensorX scratchX_HUGE = MULTI_0(output_scratch); + tensor32 scratchF = MULTI_0(local_scratch_fp32); // FP32 BTC with cuDNN, FP32 2*BTC without cuDNN (i.e. 
4xBTC BF16) + tensorX scratchX = MULTI_0(local_scratch); + // backward pass: go in the reverse order of the forward pass, and call backward() functions // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) // this was done in the fused classifier kernel as last step of forward pass // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + matmul_backward(AGRAD(lnf), PGRAD(wte), null_tensorX, AGRAD(output), ACT(lnf), PARAM(wte), scratchF, B*T, C, Vp); // backward the final layernorm - floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 - layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream); - - // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic - // scratch for backward computations - floatX* dl_btc = residual; + layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, (tensorX)AGRAD(lnf), ACT_L(residual3, L-1), + PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); // now backward all the layers - for (int l = L-1; l >= 0; l--) { + for (; l >= 0; l--) { NvtxRange layer_range("Layer", l); + tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + tensorX dresidual = (l == 0) ? AGRAD(encoded) : AGRAD_L(residual3, l-1); - residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; - - // get the pointers of the weights for this layer - floatX* l_ln1w = params.ln1w + l * C; - floatX* l_ln1b = params.ln1b + l * C; - floatX* l_qkvw = params.qkvw + l * 3*C * C; - floatX* l_attprojw = params.attprojw + l * C * C; - floatX* l_ln2w = params.ln2w + l * C; - floatX* l_ln2b = params.ln2b + l * C; - floatX* l_fcw = params.fcw + l * 4*C * C; - floatX* l_fcprojw = params.fcprojw + l * C * 4*C; - // get the pointers of the gradients of the weights for this layer - floatX* dl_ln1w = grads.ln1w + l * C; - floatX* dl_ln1b = grads.ln1b + l * C; - floatX* dl_qkvw = grads.qkvw + l * 3*C * C; - floatX* dl_qkvb = grads.qkvb + l * 3*C; - floatX* dl_attprojw = grads.attprojw + l * C * C; - floatX* dl_attprojb = grads.attprojb + l * C; - floatX* dl_ln2w = grads.ln2w + l * C; - floatX* dl_ln2b = grads.ln2b + l * C; - floatX* dl_fcw = grads.fcw + l * 4*C * C; - floatX* dl_fcb = grads.fcb + l * 4*C; - floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C; - floatX* dl_fcprojb = grads.fcprojb + l * C; - // get the pointers of the activations for this layer - floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + l * B * T; - float* l_ln1_rstd = acts.ln1_rstd + l * B * T; - floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; - floatX* l_atty = acts.atty + l * B * T * C; - floatX* l_residual2 = acts.residual2 + l * B * T * C; - floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; - float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C; - floatX* l_fch_gelu = (model->recompute < 1) ? 
acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; - // get the pointers of the gradients of the activations for this layer - // notice that there is no l *, because we just have a single copy, and keep - // re-using this memory in every Transformer block as we calculate backward pass - - floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c; - - // start the backward pass for this layer - if(model->recompute >= 1) { - // recompute >= 1 means we recompute gelu. in this case, - // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here - gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream); + if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu + gelu_forward_fp8(ACT(fch_gelu), ACT(fch)); } - matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion); - if(model->recompute >= 2) { - // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand - layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream); + matmul_backward_fp8(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), (tensorX)AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, scratchF_HUGE, B*T, 4*C, C, ACT(fch)); + + if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm + layernorm_forward((tensor8)ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); } - matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream); - // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above - layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream); - matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); + matmul_backward_fp8(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), (tensor8e5)AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, scratchF_HUGE, B*T, C, 4 * C); + layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, (tensor8e5)AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); + // AGRAD(atty) is BF16, AGRAD(residual2) is BF16, ACT(atty) is BF16, PARAM(attprojw) is BF16... ==> 100% BF16 ==> keep BF16 for now! + matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); #ifdef ENABLE_CUDNN - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C, main_stream); #else - floatX* l_att = acts.att + l * B * NH * T * T; - // we need B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = l_atty; - floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need - attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); + attention_backward(AGRAD(qkvr), scratchX, scratchX_HUGE, AGRAD(atty), ACT(qkvr), ACT(att), B, T, C, NH); #endif + if(model->recompute >= 2) { - layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream); + layernorm_forward((tensor8)ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); } - // QKV parameter gradients - matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream); - // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above - layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream); + matmul_backward_fp8(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), (tensorX)AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, scratchF_HUGE, B*T, C, 3 * C); + layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, (tensor8e5)AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); // Accumulate gradients from this layer in a background stream. if(last_step) { floatX* const pointers[] = { - dl_ln1w, dl_ln1b, - dl_qkvw, dl_qkvb, - dl_attprojw, dl_attprojb, - dl_ln2w, dl_ln2b, - dl_fcw, dl_fcb, - dl_fcprojw, dl_fcprojb + PGRAD(ln1w), PGRAD(ln1b), + PGRAD(qkvw), PGRAD(qkvb), + PGRAD(attprojw), PGRAD(attprojb), + PGRAD(ln2w), PGRAD(ln2b), + PGRAD(fcw), PGRAD(fcb), + PGRAD(fcprojw), PGRAD(fcprojb) }; const size_t nelem[] = { C, C, @@ -933,21 +911,51 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int }; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } + + // todo - this used to be bit-for-bit identical to not recomputing forward, why is it now different?! + // Is it time to redo the forward pass from our activation checkpoints? + if (LAYERS_PER_ACTIVATION_CHECKPOINT && (l % max(1, LAYERS_PER_ACTIVATION_CHECKPOINT)) == 0 && l > 0) { + int backward_l = l; + l -= LAYERS_PER_ACTIVATION_CHECKPOINT; + for (int i = 0; i < LAYERS_PER_ACTIVATION_CHECKPOINT; i++, l++) { + // non-fused layernorm as we already (only!) have the residual + // (for the original forward pass, residual of l-1 is fused with layernorm of l) + tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + + tensorX qkvr = ACT(qkvr); // non-cudnn reuses tensor with different memory pre/post-permute + qkvr.data_ptr = CUDNN_ENABLED ? 
ACT(qkvr) : MULTI(output_scratch); + matmul_forward(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + + #ifdef ENABLE_CUDNN + attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C, main_stream); + #else + attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); + #endif + + matmul_forward(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + } + l = backward_l; + } } - encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info, - dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); + + encoder_backward(PGRAD(wte), PGRAD(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, + AGRAD(encoded), model->inputs, inputs, B, T, C, random_u32(&model->rng_state)); // Aggregate all gradients that are not part of the transformer blocks if(last_step) { // reduce all the losses within the current GPU (across all microsteps) - global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream); + global_sum_deterministic(model->accumulated_mean_loss, ACT(losses), B*T, main_stream); // reduce loss across GPUs to a single, final float across all microsteps and GPUs #if MULTI_GPU ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); #endif cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); // reduce the gradients for non-transformer block parameters - floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb}; + floatX* const pointers[] = {PGRAD(wte), PGRAD(wpe), PGRAD(lnfw), PGRAD(lnfb)}; const size_t nelem[] = {Vp * C, T * C, C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } @@ -960,68 +968,29 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } } -// Gets the offset of a specific tensor for a specific layer in the GPT2 model -// layer_id is ignored for weights that are not part of a transformer block -ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) { - // first offset our way to the parameter tensor start - ptrdiff_t offset = 0; - for (int i = 0; i < param_tensor_id; i++) { - offset += (ptrdiff_t)model->param_elements[i]; - } - size_t size = model->param_elements[param_tensor_id] ; - // if we are in the transformer block, we need to additionally offset by the layer id - if(2 <= param_tensor_id && param_tensor_id <= 13) { - size /= model->config.num_layers; - offset += (ptrdiff_t)(layer_id * size); - } - return {offset, size}; -} - float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); - floatX* grads_memory = (floatX*)model->grads_memory; - // repurposing this buffer (which isn't needed now) to write grad norm into it - float* grad_norm_squared = (float*)model->acts.output; + float* grad_norm_squared = MULTI_0(output_scratch_fp32); float grad_norm_squared_cpu = 0.0f; - int num_slices[2] = {1, model->config.num_layers}; - int max_num_block_sums = 
get_max_num_block_sums(num_slices, 2); - if (multi_gpu_config->zero_stage == 1) { - // because of the ncclReduceScatter() in backward, - // grads_memory only contains the averaged gradients at the local shards, - // so we only calculate the grad norm at the grads_memory belonging to the local shards - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); - ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); - ptrdiff_t offset = tensor.offset + shard.offset; - bool is_first_pass = (i == 0); - if((i < 2 || i > 13)) { - global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, 0, 1, - max_num_block_sums, is_first_pass, main_stream); - } else { - global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, tensor.size, model->config.num_layers, - max_num_block_sums, is_first_pass, main_stream); - } - } - global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); + // automagically handles everything including sharding for ZeRO 1/2/3 + global_norm_tensors(grad_norm_squared, multi_gpu_config->process_rank, main_stream); + #if MULTI_GPU + if (multi_gpu_config->zero_stage >= 1) { // further sum the (partial) squared norm across all GPUs ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream)); -#endif - } else { - // in regular DDP, backward has averaged the gradients across all GPUs - // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed - global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream); - global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); } +#endif + cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); return grad_norm_cpu; } void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, - MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { + MultiGpuConfig* multi_gpu_config) { // update the model parameters using the AdamW optimizer // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs // so we may not be responsible for the entire parameter tensor @@ -1029,86 +998,62 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // selectively weight decay some, but not all tensors :( // TODO: revisit and probably refactor this entire function NVTX_RANGE_FN(); - if(model->grads_memory == nullptr || model->m_memory == nullptr || model->v_memory == nullptr) { + if(model->tensor_memory[PARAMETER] == nullptr || model->tensor_memory[PARAMETER_OPT_M] == nullptr || model->tensor_memory[PARAMETER_OPT_V] == nullptr) { fprintf(stderr, "Need to allocate optimizer state before update"); exit(EXIT_FAILURE); } - bool init_state = model->init_state; - if(init_state) { - model->init_state = false; - NvtxRange rng("InitOpt"); - cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - } - // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update 
= model->rng_state; - // AdamW update - // handle adamw for all the transformer blocks - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - // generate a unique seed for each tensor - unsigned int seed = random_u32(&model->rng_state); + float beta1_correction = 1.0f - powf(beta1, t); + float beta2_correction = 1.0f - powf(beta2, t); + unsigned int seed = random_u32(&model->rng_state); + int num_shards = tensor_specs[tensors_start[PARAMETER_OPT_M]].num_shards; + int shard_idx = multi_gpu_config->process_rank % num_shards; // todo - currently assuming ZeRO 1 or DDP - int num_layers = model->config.num_layers; - if((i < 2 || i > 13)) { - num_layers = 1; - } + const int block_size = 64; + const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); - ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); - ptrdiff_t local_offset_full = tensor.offset + shard.offset; - ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes; - - // we only want to weight decay the 2D tensors and leave all 1D tensors alone - // in particular this also decays the embedding weights, but this is ok: - // - the token embeddings are weight shared and participate in the final projection to logits - // - the position embeddings actively participate at every forward/backward pass - float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f; - floatX* param_ptr = (floatX*)model->params_memory + local_offset_full; - floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full; - - ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? local_offset_full : local_offset_partial; - float* m_ptr = model->m_memory + opt_state_offset; - float* v_ptr = model->v_memory + opt_state_offset; - float* master_ptr = nullptr; - if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; } - if(init_state && model->master_weights != nullptr ) { - size_t grid_size = CEIL_DIV(shard.size, 512); - copy_and_cast_kernel<<<dim3(grid_size, num_layers), 512, 0, main_stream>>>(master_ptr, param_ptr, shard.size, - shard.size, tensor.size); - cudaCheck(cudaGetLastError()); - } + int start_tensor = tensors_start[PARAMETER]; + int last_tensor = tensors_start[PARAMETER+1] - 1; + int num_tensors = last_tensor - start_tensor + 1; - if (init_from_master_only) { - // when resuming training from a checkpoint with master weights (allows changing precision) - init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream); - } else { - // ok finally call the kernel to update the weights with AdamW - adamw_update(param_ptr, master_ptr, grad_ptr, - m_ptr, v_ptr, - shard.size, tensor.size, tensor.size, shard.size, num_layers, - learning_rate, - beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); - } + if (model->use_master_weights) { + adamw_update_everything<true><<<grid_size, block_size, 0, main_stream>>>(num_tensors, start_tensor, last_tensor, seed, shard_idx, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); + } else { + adamw_update_everything<false><<<grid_size, block_size, 0, main_stream>>>(num_tensors, start_tensor, last_tensor, seed, shard_idx, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); + } - if (multi_gpu_config->zero_stage == 1) { #if MULTI_GPU - ncclCheck(ncclGroupStart()); - for(int l = 0; l < num_layers; ++l) { - // gather updated shards of model->params_memory from each process - 
ncclCheck(ncclAllGather(param_ptr + l * tensor.size, - (floatX*) model->params_memory + tensor.offset + l * tensor.size, - shard.size, ncclFloatX, - multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); - } - ncclCheck(ncclGroupEnd()); -#endif + if (multi_gpu_config->zero_stage == 1) { + ncclCheck(ncclGroupStart()); + for (int id = 0; id < num_tensors; id++) { + TensorSpec param_tensor = tensor_specs[id]; + TensorSpec opt_tensor = tensor_specs[id + tensors_start[PARAMETER_OPT_M]]; + + size_t sendcount = opt_tensor.num_elements * sizeof_dtype(opt_tensor.data_type); + void* recvbuff = param_tensor.ptr; + void* sendbuff = param_tensor.ptr + (multi_gpu_config->process_rank * sendcount); + + ncclCheck(ncclAllGather(sendbuff, recvbuff, sendcount, ncclFloatX, + multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); } + ncclCheck(ncclGroupEnd()); } - + // combine the absmax of all the GPUs + ncclCheck(ncclAllReduce(gpu_absmax_memory, gpu_absmax_memory, num_tensors, ncclFloat, ncclMax, + multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); +#endif + // todo - smarter synchronization with double buffering etc... cudaCheck(cudaDeviceSynchronize()); + + // update FP8 scale & descale multipliers based on the absmax + // since we just updated the parameters with the old scale, + // the descale of parameters is "delayed" by one step. + update_scales_from_absmax(); } float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { @@ -1141,12 +1086,9 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { } void gpt2_free(GPT2 *model) { - cudaFreeCheck(&model->params_memory); - cudaFreeCheck(&model->grads_memory); - cudaFreeCheck(&model->m_memory); - cudaFreeCheck(&model->v_memory); - cudaFreeCheck(&model->master_weights); - cudaFreeCheck(&model->acts_memory); + for (int i = 0; i < TT::COUNT; i++) { + cudaFreeCheck(&model->tensor_memory[i]); + } cudaFreeCheck(&model->inputs); cudaFreeCheck(&model->targets); cudaFreeCheck(&model->accumulated_mean_loss); @@ -1206,6 +1148,8 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) state_header[3] = multi_gpu_config.process_rank; // rank of this process state_header[4] = model->use_master_weights; // whether we're using fp32 master weights state_header[5] = loader->should_shuffle; // shuffle state of the dataloader + state_header[6] = num_tensor_specs; // number of tensor specs (must match) + state_header[7] = MAX_ABSMAX_HISTORY; // size of the absmax history (0 = disabled or old version) // int main state, start at 10 to leave some padding state_header[10] = step; // step of the optimization // model rng state, start at 20 to leave some padding @@ -1218,10 +1162,10 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->tensor_memory[PARAMETER_OPT_M], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->tensor_memory[PARAMETER_OPT_V], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, 
main_stream); + device_to_file(state_file, model->tensor_memory[PARAMETER_MASTER], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } // write dataloader state if we are using the Permuted version of it @@ -1232,6 +1176,11 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); } + + // write absmax history and scale/descale memory + device_to_file(state_file, gpu_absmax_memory, num_tensor_specs * sizeof(float) * (MAX_ABSMAX_HISTORY + 1), IO_BUF_SIZE, main_stream); + device_to_file(state_file, gpu_scale_memory, num_tensor_specs * sizeof(float) * 2, IO_BUF_SIZE, main_stream); + fcloseCheck(state_file); } @@ -1243,6 +1192,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename assert(state_header[1] == 1); // version number assert(state_header[2] == multi_gpu_config.num_processes); // number of processes assert(state_header[3] == multi_gpu_config.process_rank); // rank of this process + assert(state_header[6] == num_tensor_specs); // number of tensor specs int use_master_weights = state_header[4]; // whether we're using fp32 master weights int should_shuffle = state_header[5]; // shuffle state of the dataloader *step = state_header[10]; // step of the optimization @@ -1250,6 +1200,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard + bool restore_absmax_history = (state_header[7] == MAX_ABSMAX_HISTORY); // todo - restore even if not an exact match // read AdamW m, v, master_weights (they are all float) // allocate all the needed memory as necessary @@ -1261,18 +1212,10 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename exit(EXIT_FAILURE); } - model->init_state = false; // we just got the state from file, no need to do first-touch init - assert(model->m_memory != nullptr); - assert(model->v_memory != nullptr); - file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->tensor_memory[PARAMETER_OPT_M], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->tensor_memory[PARAMETER_OPT_V], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - assert(model->master_weights != nullptr); - file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - // restore weights from the master weights using the RNG state before last weight update - model->rng_state = model->rng_state_last_update; - gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); - model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this + file_to_device(model->tensor_memory[PARAMETER_MASTER], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } // revive the DataLoader object and its state @@ -1297,6 +1240,11 @@ void 
load_state(int* step, GPT2* model, DataLoader* loader, const char* filename } dataloader_resume(loader, current_shard_idx, current_sample_idx); + if (restore_absmax_history) { + file_to_device(gpu_absmax_memory, state_file, num_tensor_specs * sizeof(float) * (MAX_ABSMAX_HISTORY + 1), IO_BUF_SIZE, main_stream); + file_to_device(gpu_scale_memory, state_file, num_tensor_specs * sizeof(float) * 2, IO_BUF_SIZE, main_stream); + } + // all done, close state file fcloseCheck(state_file); } @@ -1559,11 +1507,15 @@ int main(int argc, char *argv[]) { // build the GPT-2 model GPT2 model; gpt2_init_common(&model); + model.use_master_weights = use_master_weights; + model.gelu_fusion = gelu_fusion; + model.recompute = recompute; + model.batch_size = B; + model.seq_len = T; + if (resuming == 1) { // if `-y 1` was set, then we are resuming from the latest checkpoint - // if we are using master weights, we'll init them later inside load_state() - bool weight_init = !use_master_weights; - gpt2_build_from_checkpoint(&model, filename_buffer, weight_init); + gpt2_build_from_checkpoint(&model, filename_buffer); } else if (ends_with_bin(load_filename)) { // otherwise, if this is a .bin file, we assume it's a model, let's init from it gpt2_build_from_checkpoint(&model, load_filename); @@ -1573,9 +1525,6 @@ int main(int argc, char *argv[]) { gpt_build_from_descriptor(&model, load_filename); } - model.use_master_weights = use_master_weights; - model.gelu_fusion = gelu_fusion; - model.recompute = recompute; printf0("| weight init method | %-50s |\n", resuming == 1 ? "intermediate checkpoint" : load_filename); printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); printf0("| vocab_size V | %-50d |\n", model.config.vocab_size); @@ -1660,12 +1609,17 @@ int main(int argc, char *argv[]) { floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX)); float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float)); - // if we found a checkpoint to resume from, load the optimization state + // if we found a checkpoint to resume from, load the optimization state (and initialize it otherwise) int step = 0; - gpt2_allocate_state(&model, B, T); if (resuming == 1) { snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); load_state(&step, &model, &train_loader, filename_buffer); + } else { + cudaCheck(cudaMemset(model.tensor_memory[PARAMETER_OPT_M], 0, multi_gpu_config.shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model.tensor_memory[PARAMETER_OPT_V], 0, multi_gpu_config.shard_num_parameters * sizeof(float))); + if (model.use_master_weights) { + init_master_weights(&model); + } } // init an OutlierDetector the training loss @@ -1757,9 +1711,9 @@ int main(int argc, char *argv[]) { // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical // (but even if it wasn't fully identical that's probably not the end of the world) // note this is still somewhat wasteful because we don't have a KV cache! 
- gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); + gpt2_forward(&model, gen_tokens, 1, T); // get the V-dimensional vector probs[0, t-1, :] - floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; + floatX* logits = ((floatX*)&model.tensor_memory[MULTIUSE][tensor_specs[model.acts.output].offset]) + (t - 1) * model.config.padded_vocab_size; // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) @@ -1837,6 +1791,14 @@ int main(int argc, char *argv[]) { // clip the gradient norm to a maximum value float grad_clip = 1.0f; float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f; + + // todo - hack - because the 1st step is now kinda useless due to FP8 absmax scaling not being ready + // todo - ideally should rerun this step so we don't "waste" the data without training on it + if (step == 0) { + step_learning_rate = 0.0f; + weight_decay = 0.0f; + } + gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); } cudaCheck(cudaEventRecord(end));
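The grad_scale fed into gpt2_update above is global-norm clipping folded into a single multiplier: compute the L2 norm over all gradients, and if it exceeds grad_clip, scale every gradient by grad_clip / grad_norm, otherwise leave it at 1.0f. A standalone sketch of that arithmetic on a toy CPU gradient buffer follows; the names are hypothetical and it has nothing to do with the real global_norm_tensors kernels, it only reproduces the scalar math.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // toy gradient buffer standing in for the full parameter-gradient memory
    std::vector<float> grads = { 3.0f, -4.0f, 12.0f }; // L2 norm = 13

    // conceptually the same reduction the GPU kernels perform:
    // sum of squares on device, then a sqrt on the host
    float grad_norm_squared = 0.0f;
    for (float g : grads) { grad_norm_squared += g * g; }
    float grad_norm = std::sqrt(grad_norm_squared);

    // identical to the scale computed right before gpt2_update()
    const float grad_clip = 1.0f;
    float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f;

    // the optimizer then sees grad_scale * grad for every element,
    // so the effective update norm is capped at grad_clip
    printf("grad_norm=%.4f grad_scale=%.4f clipped_norm=%.4f\n",
           grad_norm, grad_scale, grad_norm * grad_scale);
    return 0;
}

Because the scale is applied inside the optimizer kernel rather than by rewriting the gradient buffer, the clipping costs no extra pass over memory.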