Commit 6e4d374

fix test
Signed-off-by: Vadim Gimpelson <[email protected]>
1 parent: 3a51814

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 48 additions & 47 deletions
@@ -790,55 +790,56 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         cache_dtype="auto",
     )
     parallel_config = ParallelConfig()
-    vllm_config = VllmConfig(
-        model_config=model_config,
-        cache_config=cache_config,
-        scheduler_config=scheduler_config,
-        parallel_config=parallel_config,
-    )
-
-    layer_0 = "model.layers.0.self_attn.attn"
-    layer_1 = "model.layers.1.self_attn.attn"
-    layer_2 = "model.layers.2.mixer"
-    layer_3 = "model.layers.3.mixer"
-    layer_4 = "model.layers.4.mixer"
-    layer_5 = "model.layers.5.mixer"
-
-    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        hf_config = vllm_config.model_config.hf_config
-        fwd_context = {}
-        for key in [layer_0, layer_1]:
-            fwd_context[key] = Attention(
-                num_heads=model_config.get_num_attention_heads(parallel_config),
-                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-                head_size=model_config.get_head_size(),
-                scale=1.0,
-                prefix=key,
-            )
-        for key in [layer_2, layer_3, layer_4, layer_5]:
-            fwd_context[key] = MambaMixer2(
-                hidden_size=hf_config.hidden_size,
-                ssm_state_size=hf_config.mamba_d_state,
-                conv_kernel_size=hf_config.mamba_d_conv,
-                intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
-                use_conv_bias=hf_config.mamba_conv_bias,
-                use_bias=hf_config.mamba_proj_bias,
-                n_groups=hf_config.mamba_n_groups,
-                num_heads=hf_config.mamba_n_heads,
-                head_dim=hf_config.mamba_d_head,
-                rms_norm_eps=hf_config.rms_norm_eps,
-                activation=hf_config.hidden_act,
-                cache_config=cache_config,
-                model_config=model_config,
-                prefix=key,
-            )
-        # suppress var not used error
-        assert fwd_context is not None
-    vllm_ctx = vllm_config.compilation_config.static_forward_context
-
     with monkeypatch.context() as m:
+        # Attention backend should be set before creating VllmConfig because
+        # VllmConfig will determine the kv block size based on the attention backend
         m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        vllm_config = VllmConfig(
+            model_config=model_config,
+            cache_config=cache_config,
+            scheduler_config=scheduler_config,
+            parallel_config=parallel_config,
+        )
+
+        layer_0 = "model.layers.0.self_attn.attn"
+        layer_1 = "model.layers.1.self_attn.attn"
+        layer_2 = "model.layers.2.mixer"
+        layer_3 = "model.layers.3.mixer"
+        layer_4 = "model.layers.4.mixer"
+        layer_5 = "model.layers.5.mixer"
+
+        with set_current_vllm_config(vllm_config):
+            hf_config = vllm_config.model_config.hf_config
+            fwd_context = {}
+            for key in [layer_0, layer_1]:
+                fwd_context[key] = Attention(
+                    num_heads=model_config.get_num_attention_heads(parallel_config),
+                    num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                    head_size=model_config.get_head_size(),
+                    scale=1.0,
+                    prefix=key,
+                )
+            for key in [layer_2, layer_3, layer_4, layer_5]:
+                fwd_context[key] = MambaMixer2(
+                    hidden_size=hf_config.hidden_size,
+                    ssm_state_size=hf_config.mamba_d_state,
+                    conv_kernel_size=hf_config.mamba_d_conv,
+                    intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
+                    use_conv_bias=hf_config.mamba_conv_bias,
+                    use_bias=hf_config.mamba_proj_bias,
+                    n_groups=hf_config.mamba_n_groups,
+                    num_heads=hf_config.mamba_n_heads,
+                    head_dim=hf_config.mamba_d_head,
+                    rms_norm_eps=hf_config.rms_norm_eps,
+                    activation=hf_config.hidden_act,
+                    cache_config=cache_config,
+                    model_config=model_config,
+                    prefix=key,
+                )
+            # suppress var not used error
+            assert fwd_context is not None
+        vllm_ctx = vllm_config.compilation_config.static_forward_context
+

         runner = GPUModelRunner(vllm_config, DEVICE)
         kv_cache_spec = runner.get_kv_cache_spec()
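The comment added in this diff states the actual fix: VLLM_ATTENTION_BACKEND must be set before VllmConfig is constructed, because VllmConfig determines the KV block size from the attention backend at construction time. Below is a minimal, self-contained pytest sketch of that ordering pattern only; FakeConfig is a hypothetical stand-in, not vLLM's VllmConfig, and its block-size values are made up purely to show why building the config before the setenv would read the wrong backend.

```python
import os

import pytest


class FakeConfig:
    """Stand-in for a config object that, like VllmConfig in this test,
    reads an environment variable once at construction time."""

    def __init__(self) -> None:
        backend = os.environ.get("VLLM_ATTENTION_BACKEND", "DEFAULT")
        # Hypothetical mapping; vLLM's real block-size logic differs.
        self.block_size = 32 if backend == "FLASHINFER" else 16


def test_env_set_before_config(monkeypatch: pytest.MonkeyPatch) -> None:
    # Make sure nothing is inherited from the outer environment.
    monkeypatch.delenv("VLLM_ATTENTION_BACKEND", raising=False)

    with monkeypatch.context() as m:
        # Set the backend first, then build the config, mirroring the
        # reordering in this commit.
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
        config = FakeConfig()
        assert config.block_size == 32

    # A config built before (or outside) the setenv misses the backend
    # and falls back to the default block size.
    assert FakeConfig().block_size == 16
```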
