@@ -790,55 +790,56 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
         cache_dtype="auto",
     )
     parallel_config = ParallelConfig()
-    vllm_config = VllmConfig(
-        model_config=model_config,
-        cache_config=cache_config,
-        scheduler_config=scheduler_config,
-        parallel_config=parallel_config,
-    )
-
-    layer_0 = "model.layers.0.self_attn.attn"
-    layer_1 = "model.layers.1.self_attn.attn"
-    layer_2 = "model.layers.2.mixer"
-    layer_3 = "model.layers.3.mixer"
-    layer_4 = "model.layers.4.mixer"
-    layer_5 = "model.layers.5.mixer"
-
-    with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-        hf_config = vllm_config.model_config.hf_config
-        fwd_context = {}
-        for key in [layer_0, layer_1]:
-            fwd_context[key] = Attention(
-                num_heads=model_config.get_num_attention_heads(parallel_config),
-                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-                head_size=model_config.get_head_size(),
-                scale=1.0,
-                prefix=key,
-            )
-        for key in [layer_2, layer_3, layer_4, layer_5]:
-            fwd_context[key] = MambaMixer2(
-                hidden_size=hf_config.hidden_size,
-                ssm_state_size=hf_config.mamba_d_state,
-                conv_kernel_size=hf_config.mamba_d_conv,
-                intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
-                use_conv_bias=hf_config.mamba_conv_bias,
-                use_bias=hf_config.mamba_proj_bias,
-                n_groups=hf_config.mamba_n_groups,
-                num_heads=hf_config.mamba_n_heads,
-                head_dim=hf_config.mamba_d_head,
-                rms_norm_eps=hf_config.rms_norm_eps,
-                activation=hf_config.hidden_act,
-                cache_config=cache_config,
-                model_config=model_config,
-                prefix=key,
-            )
-        # suppress var not used error
-        assert fwd_context is not None
-    vllm_ctx = vllm_config.compilation_config.static_forward_context
-
     with monkeypatch.context() as m:
+        # Attention backend should be set before creating VllmConfig because
+        # VllmConfig will determine the kv block size based on the attention backend
         m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        vllm_config = VllmConfig(
+            model_config=model_config,
+            cache_config=cache_config,
+            scheduler_config=scheduler_config,
+            parallel_config=parallel_config,
+        )
+
+        layer_0 = "model.layers.0.self_attn.attn"
+        layer_1 = "model.layers.1.self_attn.attn"
+        layer_2 = "model.layers.2.mixer"
+        layer_3 = "model.layers.3.mixer"
+        layer_4 = "model.layers.4.mixer"
+        layer_5 = "model.layers.5.mixer"
+
+        with set_current_vllm_config(vllm_config):
+            hf_config = vllm_config.model_config.hf_config
+            fwd_context = {}
+            for key in [layer_0, layer_1]:
+                fwd_context[key] = Attention(
+                    num_heads=model_config.get_num_attention_heads(parallel_config),
+                    num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                    head_size=model_config.get_head_size(),
+                    scale=1.0,
+                    prefix=key,
+                )
+            for key in [layer_2, layer_3, layer_4, layer_5]:
+                fwd_context[key] = MambaMixer2(
+                    hidden_size=hf_config.hidden_size,
+                    ssm_state_size=hf_config.mamba_d_state,
+                    conv_kernel_size=hf_config.mamba_d_conv,
+                    intermediate_size=hf_config.mamba_expand * hf_config.hidden_size,
+                    use_conv_bias=hf_config.mamba_conv_bias,
+                    use_bias=hf_config.mamba_proj_bias,
+                    n_groups=hf_config.mamba_n_groups,
+                    num_heads=hf_config.mamba_n_heads,
+                    head_dim=hf_config.mamba_d_head,
+                    rms_norm_eps=hf_config.rms_norm_eps,
+                    activation=hf_config.hidden_act,
+                    cache_config=cache_config,
+                    model_config=model_config,
+                    prefix=key,
+                )
+            # suppress var not used error
+            assert fwd_context is not None
+        vllm_ctx = vllm_config.compilation_config.static_forward_context
+

         runner = GPUModelRunner(vllm_config, DEVICE)
         kv_cache_spec = runner.get_kv_cache_spec()
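The comment added in this change captures an ordering constraint: VLLM_ATTENTION_BACKEND has to be exported before VllmConfig is constructed, because the config derives the kv block size from the selected attention backend. Below is a minimal sketch of that ordering, not part of the diff; the build_vllm_config helper is hypothetical and only illustrates the set-env-then-construct sequence.

# Minimal sketch (hypothetical helper, mirroring the ordering fixed above).
from vllm.config import VllmConfig

def build_vllm_config(monkeypatch, model_config, cache_config,
                      scheduler_config, parallel_config):
    with monkeypatch.context() as m:
        # Export the backend first ...
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
        # ... then build the config, which reads the backend during
        # construction to choose a matching kv block size.
        return VllmConfig(
            model_config=model_config,
            cache_config=cache_config,
            scheduler_config=scheduler_config,
            parallel_config=parallel_config,
        )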