Skip to content

Commit 3c4cbdb

Browse files
committed
feat: move tensor allocation to ctor
1 parent 2355af6 commit 3c4cbdb

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,9 @@ namespace turbomind {
2222
// Constructs the layer, forwarding `param` to BaseDynamicDecodeLayer, and
// allocates both token-bitmask buffers up front. Each buffer is shaped
// {max_batch_size_, xgrammar::GetBitmaskSize(vocab_size_padded_)}.
template<typename T>
2323
GuidedDecodeMaskLayer<T>::GuidedDecodeMaskLayer(const BaseParam& param): BaseDynamicDecodeLayer{param}
2424
{
25+
// Per-row bitmask width, as reported by xgrammar for the padded vocabulary size.
const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);
26+
// Buffer placed on kCPU.
bitmask_buf_ = {{max_batch_size_, bitmask_size}, kCPU};
27+
// Buffer of the same shape placed on kDEVICE.
bitmask_ = {{max_batch_size_, bitmask_size}, kDEVICE};
2528
}
2629

2730
template<typename T>
@@ -42,16 +45,14 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)
4245
Tensor_<float> logits = args.at("logits");
4346
const ssize_t bsz = logits.shape(0);
4447

45-
FT_CHECK(bsz == matchers_.size());
48+
TM_CHECK(bsz == matchers_.size());
4649

47-
const auto bitmask_size = xgrammar::GetBitmaskSize(vocab_size_padded_);
48-
Tensor_<int32_t> bitmask{{bsz, bitmask_size}, kCPU};
49-
Tensor_<int32_t> bitmask_device{{bsz, bitmask_size}, kDEVICE};
50+
const auto bitmask_size = bitmask_buf_.shape(1);
5051
std::vector<int64_t> bitmask_shape = {bsz, bitmask_size};
5152

52-
DLTensor bitmask_dltensor{bitmask.data(),
53+
DLTensor bitmask_dltensor{bitmask_buf_.data(),
5354
DLDevice{kDLCPU, 0},
54-
bitmask.ndim(),
55+
bitmask_buf_.ndim(),
5556
xgrammar::GetBitmaskDLType(),
5657
bitmask_shape.data(),
5758
nullptr,
@@ -66,8 +67,8 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)
6667
}
6768

6869
if (need_apply) {
69-
Copy(bitmask, bitmask_device);
70-
ApplyTokenBitmaskInplace(logits, bitmask_device);
70+
Copy(bitmask_buf_, bitmask_);
71+
ApplyTokenBitmaskInplace(logits, bitmask_.slice(0, bsz));
7172
}
7273
}
7374

src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,11 @@ class GuidedDecodeMaskLayer: public BaseDynamicDecodeLayer {
3434
void Forward(TensorMap& args) override;
3535

3636
private:
37-
// host buffer
3837
std::vector<std::shared_ptr<xgrammar::GrammarMatcher>> matchers_;
38+
// host buffer
39+
Tensor_<int32_t> bitmask_buf_;
40+
// device buffer
41+
Tensor_<int32_t> bitmask_;
3942
};
4043

4144
} // namespace turbomind

0 commit comments

Comments
 (0)