@@ -22,6 +22,9 @@ namespace turbomind {
 template <typename T>
 GuidedDecodeMaskLayer<T>::GuidedDecodeMaskLayer(const BaseParam& param): BaseDynamicDecodeLayer{param}
 {
+    const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);
+    bitmask_buf_ = {{max_batch_size_, bitmask_size}, kCPU};
+    bitmask_ = {{max_batch_size_, bitmask_size}, kDEVICE};
 }

 template <typename T>
@@ -42,16 +45,14 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)
     Tensor_<float> logits = args.at("logits");
     const ssize_t  bsz    = logits.shape(0);

-    FT_CHECK(bsz == matchers_.size());
+    TM_CHECK(bsz == matchers_.size());

-    const auto           bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);
-    Tensor_<int32_t>     bitmask{{bsz, bitmask_size}, kCPU};
-    Tensor_<int32_t>     bitmask_device{{bsz, bitmask_size}, kDEVICE};
+    const auto           bitmask_size = bitmask_buf_.shape(1);
     std::vector<int64_t> bitmask_shape = {bsz, bitmask_size};

-    DLTensor bitmask_dltensor{bitmask.data(),
+    DLTensor bitmask_dltensor{bitmask_buf_.data(),
                               DLDevice{kDLCPU, 0},
-                              bitmask.ndim(),
+                              bitmask_buf_.ndim(),
                               xgrammar::GetBitmaskDLType(),
                               bitmask_shape.data(),
                               nullptr,
@@ -66,8 +67,8 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)
     }

     if (need_apply) {
-        Copy(bitmask, bitmask_device);
-        ApplyTokenBitmaskInplace(logits, bitmask_device);
+        Copy(bitmask_buf_, bitmask_);
+        ApplyTokenBitmaskInplace(logits, bitmask_.slice(0, bsz));
     }
 }

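For context: this change moves the bitmask allocation out of Forward and into the constructor, so the host staging buffer (bitmask_buf_, on CPU) and its device counterpart (bitmask_) are created once at max_batch_size_ rows instead of being reallocated for every batch; Forward then fills the buffer and applies only the first bsz rows to the logits via bitmask_.slice(0, bsz). A minimal sketch of the member declarations this implies, assuming the element type carries over from the removed Tensor_<int32_t> locals (the header itself is not part of this diff):

    // Hypothetical sketch of the new members; not shown in this diff.
    Tensor_<int32_t> bitmask_buf_;  // host-side bitmask, shape {max_batch_size_, GetBitmaskSize(vocab_size_padded_)}
    Tensor_<int32_t> bitmask_;      // device-side bitmask consumed by ApplyTokenBitmaskInplace

Because the device buffer is now sized for the maximum batch, the explicit slice(0, bsz) is what keeps ApplyTokenBitmaskInplace restricted to the rows actually filled for the current batch.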