Skip to content

Commit 8bcbfff

Browse files
committed
temp: enable debug_print
1 parent d8bb401 commit 8bcbfff

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <cuda_fp16.h>
2323
#include <cuda_runtime.h>
2424

25+
#include "src/turbomind/core/context.h"
2526
#include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h"
2627
// clang-format on
2728

@@ -140,27 +141,28 @@ void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(T* __restrict__ logits,
140141
const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
141142
const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);
142143

143-
const dim3 block(THREADS_PER_THREAD_BLOCK);
144+
const dim3 block(THREADS_PER_THREAD_BLOCK);
145+
const auto& stream = turbomind::core::Context::stream();
144146

145147
if (num_bits_per_thread <= 4 && kAlignment <= 4) {
146148
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
147149
LogitsBitmaskKernel<T, PackedT, 4>
148-
<<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
150+
<<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
149151
}
150152
else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
151153
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
152154
LogitsBitmaskKernel<T, PackedT, 8>
153-
<<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
155+
<<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
154156
}
155157
else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
156158
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
157159
LogitsBitmaskKernel<T, PackedT, 16>
158-
<<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
160+
<<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
159161
}
160162
else {
161163
const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
162164
LogitsBitmaskKernel<T, PackedT, 32>
163-
<<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
165+
<<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
164166
}
165167
}
166168

src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)
4747
const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);
4848
Tensor_<int32_t> bitmask{{bsz, bitmask_size}, kCPU};
4949
Tensor_<int32_t> bitmask_device{{bsz, bitmask_size}, kDEVICE};
50-
std::vector<int64_t> bitmap_shape = {bsz, bitmask_size};
50+
std::vector<int64_t> bitmask_shape = {bsz, bitmask_size};
5151

5252
DLTensor bitmask_dltensor{bitmask.data(),
5353
DLDevice{kDLCPU, 0},
54-
static_cast<int32_t>(bitmap_shape.size()),
54+
bitmask.ndim(),
5555
xgrammar::GetBitmaskDLType(),
56-
bitmap_shape.data(),
56+
bitmask_shape.data(),
5757
nullptr,
5858
0};
5959
bool need_apply = false;
@@ -67,6 +67,8 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)
6767

6868
if (need_apply) {
6969
Copy(bitmask, bitmask_device);
70+
71+
// cudaDeviceSynchronize();
7072
ApplyTokenBitmaskInplace(logits, bitmask_device);
7173
}
7274
}

0 commit comments

Comments (0)