|
22 | 22 | #include <cuda_fp16.h>
|
23 | 23 | #include <cuda_runtime.h>
|
24 | 24 |
|
| 25 | +#include "src/turbomind/core/context.h" |
25 | 26 | #include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h"
|
26 | 27 | // clang-format on
|
27 | 28 |
|
@@ -140,27 +141,28 @@ void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(T* __restrict__ logits,
|
140 | 141 | const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
|
141 | 142 | const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);
|
142 | 143 |
|
143 |
| - const dim3 block(THREADS_PER_THREAD_BLOCK); |
| 144 | + const dim3 block(THREADS_PER_THREAD_BLOCK); |
| 145 | + const auto& stream = turbomind::core::Context::stream(); |
144 | 146 |
|
145 | 147 | if (num_bits_per_thread <= 4 && kAlignment <= 4) {
|
146 | 148 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
|
147 | 149 | LogitsBitmaskKernel<T, PackedT, 4>
|
148 |
| - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 150 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
149 | 151 | }
|
150 | 152 | else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
|
151 | 153 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
|
152 | 154 | LogitsBitmaskKernel<T, PackedT, 8>
|
153 |
| - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 155 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
154 | 156 | }
|
155 | 157 | else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
|
156 | 158 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
|
157 | 159 | LogitsBitmaskKernel<T, PackedT, 16>
|
158 |
| - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 160 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
159 | 161 | }
|
160 | 162 | else {
|
161 | 163 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
|
162 | 164 | LogitsBitmaskKernel<T, PackedT, 32>
|
163 |
| - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 165 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
164 | 166 | }
|
165 | 167 | }
|
166 | 168 |
|
|
0 commit comments