 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, is_pin_memory_available
-from vllm.utils.flashinfer import (supports_trtllm_attention,
+from vllm.utils.flashinfer import (flashinfer_disable_q_quantization,
+                                   supports_trtllm_attention,
                                    use_trtllm_attention)
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 # yapf conflicts with isort for this block
 logger = init_logger(__name__)
 
 
-class FlashInferBackend(AttentionBackend):
+@triton.jit
+def _trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache_ptr,
+    block_tables_prefill_ptr,
+    block_table_stride,
+    mock_kv_cache_ptr,
+    k_scale_ptr,
+    v_scale_ptr,
+    K_CACHE_STRIDE: tl.constexpr,
+    KV_CACHE_STRIDE: tl.constexpr,
+):
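+    # Each program instance handles one (batch, page-slot) pair: it loads one
+    # FP8 K page and one FP8 V page from the paged KV cache, scales them by
+    # k_scale/v_scale, and stores the dequantized pages into the densely
+    # packed mock KV cache (page 0 of the mock cache is left unused, matching
+    # the sequential mock block table that starts at 1).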
+    batch_idx = tl.program_id(0).to(tl.int64)
+    mock_block_table_idx = tl.program_id(1).to(tl.int64)
+    orig_page_num = tl.load(block_tables_prefill_ptr +
+                            batch_idx * block_table_stride +
+                            mock_block_table_idx).to(tl.int64)
+    if orig_page_num <= 0:
+        return
+    dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
+
+    # Dequantize K
+    k_scale_val = tl.load(k_scale_ptr)
+    offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
+    mock_cache_offset = (batch_idx * block_table_stride + mock_block_table_idx
+                         + 1) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+    # Dequantize V
+    v_scale_val = tl.load(v_scale_ptr)
+    offset = (orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE +
+              tl.arange(0, K_CACHE_STRIDE))
+    fp8_vals = tl.load(kv_cache_ptr + offset)
+    dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
+    mock_cache_offset = (
+        (batch_idx * block_table_stride + mock_block_table_idx + 1) *
+        KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE))
+    dequantized_vals = dequantized_vals.to(dequant_dtype)
+    tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
+
+
+def trtllm_prefill_attn_kvfp8_dequant(
+    kv_cache: torch.Tensor,
+    block_tables_prefill: torch.Tensor,
+    k_scale: torch.Tensor,
+    v_scale: torch.Tensor,
+    dequant_dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
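+    """Dequantize the FP8 pages referenced by block_tables_prefill into a
+    freshly allocated bf16/fp16 "mock" KV cache.
+
+    kv_cache is the paged FP8 cache of shape (num_pages, 2, ...), holding a
+    K plane and a V plane per page; only the product of the trailing dims
+    is used here. Returns the mock cache together with a block table that
+    indexes its pages sequentially, so the pair can stand in for the FP8
+    cache and the original block table in the TRTLLM prefill call.
+    """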
+    batch_size, num_of_page_per_token = block_tables_prefill.shape
+    s = kv_cache.shape
+    assert s[1] == 2
+    assert dequant_dtype in (torch.bfloat16, torch.float16)
+    k_cache_stride = s[2] * s[3] * s[4]
+    kv_cache_stride = k_cache_stride * s[1]
+    new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
+    # mock kv cache contains just the pages needed by this prefill
+    mock_kv_cache = torch.empty(new_s,
+                                dtype=dequant_dtype,
+                                device=kv_cache.device)
+    # we simply sequentially index the pages needed by this prefill
+    mock_block_table = torch.arange(
+        start=1,
+        end=batch_size * num_of_page_per_token + 1,
+        dtype=torch.int32,
+        device=block_tables_prefill.device,
+    ).reshape(batch_size, num_of_page_per_token)
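+    # One program per (sequence, page-slot) pair. The kernel loads a whole
+    # K (or V) page with a single tl.arange, so the per-page element count
+    # (s[2] * s[3] * s[4]) must be a power of two.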
+    grid = (batch_size, num_of_page_per_token)
+    _trtllm_prefill_attn_kvfp8_dequant[grid](
+        kv_cache,
+        block_tables_prefill,
+        num_of_page_per_token,
+        mock_kv_cache,
+        k_scale,
+        v_scale,
+        k_cache_stride,
+        kv_cache_stride,
+    )
+    return mock_kv_cache, mock_block_table
+
 
+class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
     @classmethod
@@ -122,7 +204,6 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
 
 @dataclass
 class FlashInferMetadata:
-
     num_actual_tokens: int  # Number of tokens excluding padding.
 
     # The data type of the query
@@ -175,8 +256,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                                      self.kv_cache_spec.block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
-        self.enable_cuda_graph = self.compilation_config.cudagraph_mode.\
-            decode_mode() == CUDAGraphMode.FULL
+        self.enable_cuda_graph = (self.compilation_config.cudagraph_mode.\
+            decode_mode() == CUDAGraphMode.FULL)
         if self.enable_cuda_graph:
             # For full cudagraph capture, one `decode_wrapper` for each batch
             # size is needed for FlashInfer.
@@ -201,7 +282,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
 
-        if supports_trtllm_attention()[0]:
+        if supports_trtllm_attention()[0] and \
+                not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
             self.q_data_type = self.model_config.dtype
@@ -795,11 +877,29 @@ def forward(
                     assert self.o_sf_scale is None
                     out = output[num_decode_tokens:]
 
+                if attn_metadata.q_data_type != FP8_DTYPE \
+                        and self.kv_cache_dtype.startswith("fp8"):
+                    # TRTLLM prefill attention does not support a BF16 query
+                    # together with an fp8 KV cache. To still run prefill
+                    # attention against an fp8 KV cache, construct a mock
+                    # block table and a mock KV cache holding dequantized
+                    # (BF16) copies of the pages used by this prefill.
+                    mock_kv_cache, mock_block_table = (
+                        trtllm_prefill_attn_kvfp8_dequant(
+                            kv_cache_permute,
+                            block_tables_prefill,
+                            layer._k_scale,
+                            layer._v_scale,
+                            attn_metadata.q_data_type,
+                        ))
+                else:
+                    mock_kv_cache = kv_cache_permute
+                    mock_block_table = block_tables_prefill
+
                 trtllm_batch_context_with_kv_cache(
                     query=prefill_query,
-                    kv_cache=kv_cache_permute,
+                    kv_cache=mock_kv_cache,
                     workspace_buffer=workspace_buffer,
-                    block_tables=block_tables_prefill,
+                    block_tables=mock_block_table,
                     seq_lens=seq_lens_prefill,
                     max_q_len=attn_metadata.max_q_len,
                     max_kv_len=attn_metadata.max_seq_len,
@@ -837,7 +937,7 @@ def forward(
                 decode_query = decode_query.contiguous()
                 workspace_buffer = decode_wrapper._float_workspace_buffer
                 block_tables_decode = attn_metadata.\
-                        block_table_tensor[:num_decode_tokens]
+                    block_table_tensor[:num_decode_tokens]
                 seq_lens_decode = attn_metadata.seq_lens[:num_decode_tokens]
 
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
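
Below is a small, self-contained sketch of how the new trtllm_prefill_attn_kvfp8_dequant helper could be exercised on its own. It is illustrative only: the shapes, scale values and block-table contents are made up, the import path assumes the helper lives in vllm/v1/attention/backends/flashinfer.py as in this diff, and running it requires a CUDA GPU with a Triton/PyTorch build that supports float8_e4m3fn.

    # Hypothetical standalone usage of the FP8 KV dequant helper.
    import torch
    from vllm.v1.attention.backends.flashinfer import (
        trtllm_prefill_attn_kvfp8_dequant)

    # 8 pages of paged FP8 KV cache; the per-page K size 4 * 16 * 128 = 8192
    # elements is a power of two, as the kernel's tl.arange load requires.
    num_pages, num_kv_heads, page_size, head_dim = 8, 4, 16, 128
    kv_cache = torch.randn(num_pages, 2, num_kv_heads, page_size, head_dim,
                           device="cuda").to(torch.float8_e4m3fn)
    # Two prefill requests with two page slots each; 0 marks an unused slot.
    block_tables = torch.tensor([[1, 2], [3, 0]],
                                dtype=torch.int32, device="cuda")
    k_scale = torch.tensor([0.5], device="cuda")
    v_scale = torch.tensor([0.25], device="cuda")

    mock_cache, mock_table = trtllm_prefill_attn_kvfp8_dequant(
        kv_cache, block_tables, k_scale, v_scale, torch.bfloat16)
    # mock_cache: bf16, shape (2 * 2 + 1, 2, 4, 16, 128), page 0 unused.
    # mock_table: [[1, 2], [3, 4]]. This pair is what forward() passes to the
    # TRTLLM prefill kernel in place of the FP8 cache and original block table.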