
Commit d75e11e

test: add matrix grammar test
1 parent d562d8e commit d75e11e

7 files changed: +78 -48 lines changed

.github/workflows/unit-test.yml

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ jobs:
          python3 -m pip install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118
      - name: Install lmdeploy
        run: |
-         python3 -m pip install pynvml packaging protobuf transformers_stream_generator matplotlib
+         python3 -m pip install pynvml packaging protobuf transformers_stream_generator matplotlib timm
          # manually install flash attn
          python3 -m pip install /root/packages/cu118/flash_attn-*.whl
          python3 -m pip install -r requirements_cuda.txt -r requirements/test.txt
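The extra timm package is presumably required at import time by the new InternVL3_5 test model introduced in tests/test_lmdeploy/test_grammar.py below; that link is an inference from this commit, not stated in it. A minimal post-install sanity check in the same style as the workflow's other commands could be:

python3 -c "import timm; print(timm.__version__)"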

CMakeLists.txt

Lines changed: 5 additions & 2 deletions
@@ -82,7 +82,7 @@ FetchContent_MakeAvailable(yaml-cpp)
 FetchContent_Declare(
   xgrammar
   GIT_REPOSITORY https://github.com/mlc-ai/xgrammar.git
-  GIT_TAG v0.1.21
+  GIT_TAG v0.1.25
   GIT_SUBMODULES "3rdparty/dlpack"
   GIT_PROGRESS TRUE
   USES_TERMINAL_DOWNLOAD TRUE
@@ -94,7 +94,10 @@ if(NOT xgrammar_POPULATED)
   # Fetch the content using previously declared details
   FetchContent_Populate(xgrammar)

-  file(WRITE ${xgrammar_SOURCE_DIR}/config.cmake "set(XGRAMMAR_BUILD_PYTHON_BINDINGS OFF)")
+  file(WRITE ${xgrammar_SOURCE_DIR}/config.cmake "set(XGRAMMAR_BUILD_PYTHON_BINDINGS OFF)\n")
+  if(NOT MSVC)
+    file(APPEND ${xgrammar_SOURCE_DIR}/config.cmake "set(CMAKE_CXX_FLAGS \"-Wno-error\")\n")
+  endif()

   # Bring the populated content into the build
   add_subdirectory(${xgrammar_SOURCE_DIR} ${xgrammar_BINARY_DIR})
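On non-MSVC toolchains, the generated config.cmake consumed by xgrammar now reads (reconstructed from the file(WRITE)/file(APPEND) calls above):

set(XGRAMMAR_BUILD_PYTHON_BINDINGS OFF)
set(CMAKE_CXX_FLAGS "-Wno-error")

Appending -Wno-error presumably keeps warnings in the newly pinned v0.1.25 sources from aborting the build when the parent project compiles with -Werror; MSVC is skipped because it does not accept that GCC/Clang flag.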

requirements/runtime_cuda.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ fire
 mmengine-lite
 numpy<2.0.0
 openai
-outlines
+outlines<0.1.0
 partial_json_parser
 peft<=0.14.0
 pillow

requirements/runtime_rocm.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ fire
 mmengine-lite
 numpy<2.0.0
 openai
-outlines
+outlines<0.1.0
 partial_json_parser
 peft<=0.14.0
 pillow
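Both runtime requirement files receive the same pin. Capping outlines below 0.1.0 presumably avoids the breaking API reorganization that outlines shipped in its 0.1.0 release; the commit itself does not state the reason. The specifier behaves as expected under packaging (version numbers below are illustrative):

from packaging.specifiers import SpecifierSet

# '<0.1.0' admits any pre-0.1.0 release and rejects 0.1.0 itself.
assert '0.0.46' in SpecifierSet('<0.1.0')
assert '0.1.0' not in SpecifierSet('<0.1.0')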

src/turbomind/kernels/apply_token_bitmask_inplace_cuda.cu

Lines changed: 21 additions & 6 deletions
@@ -214,19 +214,34 @@ void ApplyTokenBitmaskInplace(Tensor logits, Tensor bitmask, std::optional<Tenso

     switch (logits.dtype()) {
         case kFloat32: {
-            ApplyTokenBitmaskInplaceDispatchToPackedT(
-                logits.data<float>(), bitmask.data<int32_t>(), indices_ptr, vocab_size, 0, 0, num_rows);
+            ApplyTokenBitmaskInplaceDispatchToPackedT(logits.data<float>(),
+                                                      bitmask.data<int32_t>(),
+                                                      indices_ptr,
+                                                      vocab_size,
+                                                      logits.stride(0),
+                                                      bitmask.stride(0),
+                                                      num_rows);
             break;
         }
         case kFloat16: {
-            ApplyTokenBitmaskInplaceDispatchToPackedT(
-                logits.data<half_t>(), bitmask.data<int32_t>(), indices_ptr, vocab_size, 0, 0, num_rows);
+            ApplyTokenBitmaskInplaceDispatchToPackedT(logits.data<half_t>(),
+                                                      bitmask.data<int32_t>(),
+                                                      indices_ptr,
+                                                      vocab_size,
+                                                      logits.stride(0),
+                                                      bitmask.stride(0),
+                                                      num_rows);
             break;
         }
 #if __CUDA_ARCH__ >= 800
         case kBfloat16: {
-            ApplyTokenBitmaskInplaceDispatchToPackedT(
-                logits.data<bfloat16_t>(), bitmask.data<int32_t>(), indices_ptr, vocab_size, 0, 0, num_rows);
+            ApplyTokenBitmaskInplaceDispatchToPackedT(logits.data<bfloat16_t>(),
+                                                      bitmask.data<int32_t>(),
+                                                      indices_ptr,
+                                                      vocab_size,
+                                                      logits.stride(0),
+                                                      bitmask.stride(0),
+                                                      num_rows);
             break;
         }
 #endif
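The substantive change is passing logits.stride(0) and bitmask.stride(0) instead of hard-coded zeros, so the dispatch sees the real row pitch when the tensors are non-contiguous views. A short torch sketch of the failure mode this avoids (illustrative only, not lmdeploy code):

import torch

vocab_size = 8
buf = torch.randn(4, 16)        # padded backing buffer
logits = buf[:, :vocab_size]    # view: shape (4, 8) but stride(0) == 16

# Row i starts at i * stride(0) elements in memory, not i * vocab_size,
# so a kernel assuming a contiguous layout reads the wrong rows for i > 0.
row2 = buf.reshape(-1)[2 * logits.stride(0):2 * logits.stride(0) + vocab_size]
assert torch.equal(row2, logits[2])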

src/turbomind/python/xgrammar_bind.cpp

Lines changed: 6 additions & 13 deletions
@@ -107,15 +107,7 @@ PYBIND11_MODULE(_xgrammar, m)
             return TokenizerInfo::FromVocabAndMetadata(CommonEncodedVocabType(encoded_vocab), metadata);
         })

-        .def_static("_detect_metadata_from_hf", &TokenizerInfo::DetectMetadataFromHF)
-
-        .def("serialize_json", &TokenizerInfo::SerializeJSON)
-
-        .def_static(
-            "deserialize_json",
-            [](const std::string& str, const py::typing::List<std::variant<std::string, py::bytes>>& encoded_vocab) {
-                return TokenizerInfo::DeserializeJSON(str, CommonEncodedVocabType(encoded_vocab));
-            });
+        .def_static("_detect_metadata_from_hf", &TokenizerInfo::DetectMetadataFromHF);

     py::class_<CompiledGrammar>(m, "CompiledGrammar");

@@ -130,10 +122,11 @@ PYBIND11_MODULE(_xgrammar, m)
              &GrammarCompiler::CompileJSONSchema,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("schema"),
-             py::arg("any_whitespace") = false,
-             py::arg("indent") = py::none(),
-             py::arg("separators") = py::none(),
-             py::arg("strict_mode") = true)
+             py::arg("any_whitespace") = false,
+             py::arg("indent") = py::none(),
+             py::arg("separators") = py::none(),
+             py::arg("strict_mode") = true,
+             py::arg("max_whitespace_cnt") = py::none())
         .def("compile_regex",
              &GrammarCompiler::CompileRegex,
              py::call_guard<py::gil_scoped_release>(),
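Seen from Python, the binding change means compile_json_schema gains one keyword argument. A hedged sketch of the resulting signature, with names read off the py::arg calls above and types guessed from their defaults (not official API documentation):

def compile_json_schema(schema: str,
                        any_whitespace: bool = False,
                        indent: 'int | None' = None,
                        separators: 'tuple | None' = None,
                        strict_mode: bool = True,
                        max_whitespace_cnt: 'int | None' = None) -> 'CompiledGrammar':
    ...

The serialize_json/deserialize_json bindings are dropped, presumably because the corresponding TokenizerInfo methods changed or disappeared between xgrammar v0.1.21 and the newly pinned v0.1.25.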

tests/test_lmdeploy/test_grammar.py

Lines changed: 43 additions & 24 deletions
@@ -4,20 +4,19 @@
 from jsonschema import validate

 from lmdeploy import pipeline
-from lmdeploy.messages import GenerationConfig, TurbomindEngineConfig
+from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig

+MODEL_IDS = [
+    'Qwen/Qwen3-0.6B',
+    'OpenGVLab/InternVL3_5-1B',
+]

-@pytest.fixture(scope='module')
-def tiny_model_id():
-    return 'internlm/internlm2_5-1_8b'
+BACKEND_FACTORIES = [
+    ('tm', lambda: TurbomindEngineConfig(max_batch_size=2, session_len=1024)),
+    ('pt', lambda: PytorchEngineConfig(max_batch_size=1, session_len=1024)),
+]

-
-@pytest.fixture(scope='module')
-def tmp_workspace(tmp_path_factory):
-    return tmp_path_factory.mktemp('tm_workspace')
-
-
-guide = {
+GUIDE_SCHEMA = {
     'type': 'object',
     'properties': {
         'name': {
@@ -29,7 +28,8 @@ def tmp_workspace(tmp_path_factory):
                 'type': 'string',
                 'maxLength': 10
             },
-            'minItems': 3
+            'minItems': 3,
+            'maxItems': 10,
         },
         'work history': {
             'type': 'array',
@@ -41,20 +41,39 @@ def tmp_workspace(tmp_path_factory):
                     },
                     'duration': {
                         'type': 'string'
-                    }
+                    },
                 },
-                'required': ['company']
-            }
-        }
+                'required': ['company'],
+            },
+        },
     },
-    'required': ['name', 'skills', 'work history']
+    'required': ['name', 'skills', 'work history'],
 }


-def test_tm_guided_pipeline(tiny_model_id):
-    pipe = pipeline(tiny_model_id,
-                    backend_config=TurbomindEngineConfig(max_batch_size=1, session_len=1024),
-                    log_level='INFO')
-    gen_config = GenerationConfig(response_format=dict(type='json_schema', json_schema=dict(name='test', schema=guide)))
-    response = pipe(['Make a self introduction please.'], gen_config=gen_config)
-    validate(instance=json.loads(response[0].text), schema=guide)
+@pytest.mark.parametrize('model_id', MODEL_IDS)
+@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
+@pytest.mark.parametrize('enable_guide', [True, False])
+def test_guided_matrix(model_id, backend_name, backend_factory, enable_guide):
+    pipe = pipeline(
+        model_id,
+        backend_config=backend_factory(),
+        log_level='INFO',
+    )
+
+    try:
+        if enable_guide:
+            gen_config = GenerationConfig(response_format=dict(
+                type='json_schema',
+                json_schema=dict(name='test', schema=GUIDE_SCHEMA),
+            ), )
+        else:
+            gen_config = GenerationConfig()
+
+        response = pipe(['Make a self introduction please.'] * 3, gen_config=gen_config)
+        assert response and response[0].text
+
+        if enable_guide:
+            validate(instance=json.loads(response[0].text), schema=GUIDE_SCHEMA)
+    finally:
+        pipe.close()
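For reference, a response shaped like the following would satisfy GUIDE_SCHEMA, so the test's final validate call passes; the concrete values are invented for illustration, and GUIDE_SCHEMA is assumed to be in scope (e.g., imported from this test module):

import json
from jsonschema import validate

sample = {
    'name': 'Alice',
    'skills': ['python', 'cuda', 'cmake'],  # 3-10 items, each at most 10 chars
    'work history': [{'company': 'Acme', 'duration': '2 years'}],
}
validate(instance=sample, schema=GUIDE_SCHEMA)  # raises ValidationError on mismatch
print(json.dumps(sample))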
