
Commit e616e00

added missing stream argument for repkv_backward
1 parent: 8a1893e

2 files changed, +3 −3 lines changed

llmc/repkv.cuh

Lines changed: 2 additions & 2 deletions
@@ -111,11 +111,11 @@ void repkv_forward(floatX* out, const floatX* inp, int B, int T, int NH, int NH_
 }

 void repkv_backward(floatX* dinp, const floatX* dout,
-                    const int B, const int T, const int NH, const int NH_KV, const int d) {
+                    const int B, const int T, const int NH, const int NH_KV, const int d, cudaStream_t stream) {
     const int block_size = 128;
     int total_threads = B * T * (3 * NH) * d;
     int num_blocks = CEIL_DIV(total_threads, block_size);
     int replicate_factor = NH / NH_KV;
-    repkv_backward_kernel1<<<num_blocks, block_size>>>(dinp, dout, B, T, NH, replicate_factor, d);
+    repkv_backward_kernel1<<<num_blocks, block_size, 0, stream>>>(dinp, dout, B, T, NH, replicate_factor, d);
     cudaCheck(cudaGetLastError());
 }
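
For context (not part of the commit itself): CUDA's execution configuration takes up to four arguments, <<<grid, block, sharedMemBytes, stream>>>, so passing a stream also requires spelling out the dynamic shared-memory size, which is why the launch gains ", 0, stream". Before the change the kernel was launched with only <<<num_blocks, block_size>>>, which implicitly targets the default stream. A minimal standalone sketch of the same launch pattern follows; the kernel name and sizes are illustrative, not from this repo.

#include <cuda_runtime.h>
#include <cstdio>

// illustrative kernel: scales an array in place
__global__ void scale_kernel(float* data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { data[i] *= factor; }
}

int main() {
    const int n = 1 << 20;
    float* d_data;
    cudaMalloc(&d_data, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    const int block_size = 128;
    int num_blocks = (n + block_size - 1) / block_size;
    // third config argument is dynamic shared memory (0 bytes), fourth is the stream
    scale_kernel<<<num_blocks, block_size, 0, stream>>>(d_data, 2.0f, n);

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    cudaFree(d_data);
    printf("done\n");
    return 0;
}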

train_llama3.cu

Lines changed: 1 addition & 1 deletion
@@ -922,7 +922,7 @@ void llama3_backward_and_reduce(LLama3 *model, int* inputs, const int* targets,
         floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
         attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
         // backward repkv (use scratchX as gradient buffer here)
-        repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd);
+        repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd, main_stream);
 #endif
         // backward rope (this can be done in-place)
         rope_backward_inplace(dl_bt4c2, dl_bt4c2, model->freqs_cis, B, T, NH, n_kv_head, hd, main_stream);
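
A hedged note on why threading main_stream through matters (not stated in the commit): the surrounding calls such as attention_backward and rope_backward_inplace already run on main_stream, so launching repkv_backward's kernel on the same stream keeps it ordered between them without extra synchronization, whereas the old launch went to the default stream. The sketch below shows that same-stream ordering guarantee with hypothetical produce/consume kernels; none of these names come from this repo.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void produce(float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) buf[i] = 1.0f;
}

__global__ void consume(float* out, const float* buf, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = buf[i] * 2.0f;  // safe: same-stream launches execute in issue order
}

int main() {
    const int n = 1024, block_size = 128;
    float *buf, *out;
    cudaMalloc(&buf, n * sizeof(float));
    cudaMalloc(&out, n * sizeof(float));

    cudaStream_t main_stream;
    cudaStreamCreateWithFlags(&main_stream, cudaStreamNonBlocking);

    int num_blocks = (n + block_size - 1) / block_size;
    produce<<<num_blocks, block_size, 0, main_stream>>>(buf, n);
    // ordered after produce() because both launches target main_stream;
    // omitting the stream would send this launch to the default stream instead
    consume<<<num_blocks, block_size, 0, main_stream>>>(out, buf, n);

    cudaStreamSynchronize(main_stream);
    cudaStreamDestroy(main_stream);
    cudaFree(buf);
    cudaFree(out);
    printf("ok\n");
    return 0;
}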
