
Commit 2c18726

fa3 cudagraph
1 parent 465b533 commit 2c18726

6 files changed: +150 additions, -49 deletions

lmdeploy/pytorch/backends/cuda/attention.py

Lines changed: 33 additions & 31 deletions
@@ -42,6 +42,9 @@ class TritonAttentionMetadata(AttentionMetadata):
     num_splits: torch.Tensor = None
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
+    # flash attn
+    scheduler_metadata: torch.Tensor = None
+    max_kv_seqlen: int = None


 def _cdiv(a, b):
@@ -477,42 +480,41 @@ def forward(
             v_scales_zeros=v_scales_zeros,
             quant_policy=quant_policy,
         )
-        # sliding_window = (-1, -1) if self.sliding_window is None else self.sliding_window
-        # if isinstance(sliding_window, int):
-        #     sliding_window = (sliding_window, sliding_window)
-        # attn_output = self.flash_attn_with_kvcache_v3(
-        #     query,
-        #     k_cache,
-        #     v_cache,
-        #     cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
-        #     cu_seqlens_q=attn_metadata.cu_seqlens_q,
-        #     cu_seqlens_k_new=attn_metadata.cu_seqlens_k,
-        #     max_seqlen_q=max_q_seqlen,
-        #     page_table=block_offsets,
-        #     softmax_scale=self.scale,
-        #     causal=self.causal,
-        #     window_size=sliding_window,
-        #     softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
-        # )
-        # return attn_output
         if is_decoding:
-            q_shape = query.shape
-            o_shape = q_shape[:-1] + (self.v_head_size, )
-            attn_output = query.new_empty(o_shape)
-            self.paged_attention_fwd(
+            sliding_window = (-1, -1) if self.sliding_window is None else self.sliding_window
+            if isinstance(sliding_window, int):
+                sliding_window = (sliding_window, sliding_window)
+            query = query.unflatten(0, (-1, max_q_seqlen))
+            attn_output = self.flash_attn_with_kvcache_v3(
                 query,
                 k_cache,
                 v_cache,
-                attn_output,
-                block_offsets,
-                kv_seqlens=kv_seqlens,
-                k_scales_zeros=k_scales_zeros,
-                v_scales_zeros=v_scales_zeros,
-                quant_policy=quant_policy,
-                window_size=self.sliding_window,
-                sm_scale=self.scale,
-                logit_softcapping=self.logit_softcapping,
+                cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
+                max_seqlen_q=max_q_seqlen,
+                scheduler_metadata=attn_metadata.scheduler_metadata,
+                page_table=block_offsets,
+                softmax_scale=self.scale,
+                causal=self.causal,
+                window_size=sliding_window,
+                softcap=-1.0 if self.logit_softcapping is None else self.logit_softcapping,
             )
+            # q_shape = query.shape
+            # o_shape = q_shape[:-1] + (self.v_head_size, )
+            # attn_output = query.new_empty(o_shape)
+            # self.paged_attention_fwd(
+            #     query,
+            #     k_cache,
+            #     v_cache,
+            #     attn_output,
+            #     block_offsets,
+            #     kv_seqlens=kv_seqlens,
+            #     k_scales_zeros=k_scales_zeros,
+            #     v_scales_zeros=v_scales_zeros,
+            #     quant_policy=quant_policy,
+            #     window_size=self.sliding_window,
+            #     sm_scale=self.scale,
+            #     logit_softcapping=self.logit_softcapping,
+            # )
         else:
             # sliding_window = (-1, -1) if self.sliding_window is None else self.sliding_window
             # if isinstance(sliding_window, int):
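Note on the new decoding branch: the engine hands this layer queries flattened over tokens, while FlashAttention's KV-cache kernel takes a (batch, seqlen_q, num_heads, head_dim) query, which is why the diff inserts query.unflatten(0, (-1, max_q_seqlen)) before the call. A minimal sketch of that reshape, using toy sizes that are not taken from the commit:

    import torch

    # Toy sizes for illustration only.
    batch, max_q_seqlen, num_heads, head_dim = 4, 2, 8, 64

    # Engine-side layout: all query tokens flattened into one leading dimension.
    query = torch.randn(batch * max_q_seqlen, num_heads, head_dim)

    # Restore the batch dimension expected by the KV-cache kernel.
    query = query.unflatten(0, (-1, max_q_seqlen))
    assert query.shape == (batch, max_q_seqlen, num_heads, head_dim)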

lmdeploy/pytorch/backends/cuda/op_backend.py

Lines changed: 70 additions & 3 deletions
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Tuple
+from typing import Optional, Tuple

 import torch

@@ -19,6 +19,46 @@ def _get_meta_flashmla(kv_seqlens, num_attention_heads):
     return tile_scheduler_metadata, num_splits


+def _get_meta_flashattn(
+    batch_size: int,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    num_heads_q: int,
+    num_heads_kv: int,
+    headdim: int,
+    cache_seqlens: torch.Tensor,
+    qkv_dtype=torch.bfloat16,
+    headdim_v=None,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k_new: Optional[torch.Tensor] = None,
+    page_size: Optional[int] = None,
+    causal=True,
+    window_size=(-1, -1),  # -1 means infinite context window
+    num_splits=0,
+):
+    """Get scheduler metadata for flash attn."""
+    from flash_attn_interface import get_scheduler_metadata
+
+    metadata = get_scheduler_metadata(
+        batch_size,
+        max_seqlen_q,
+        max_seqlen_k,
+        num_heads_q,
+        num_heads_kv,
+        headdim,
+        cache_seqlens,
+        qkv_dtype=qkv_dtype,
+        headdim_v=headdim_v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k_new=cu_seqlens_k_new,
+        page_size=page_size,
+        causal=causal,
+        window_size=window_size,
+        num_splits=num_splits,
+    )
+    return metadata
+
+
 class CudaOpsBackend(DefaultOpsBackend):
     """Cuda layer backend."""

@@ -121,6 +161,28 @@ def update_meta_flashmla(cls, attn_metadata, num_attention_heads):
         if attn_metadata.block_offsets.dtype != torch.int32:
             attn_metadata.block_offsets = attn_metadata.block_offsets.to(torch.int32)

+    @classmethod
+    def update_meta_flashattn(cls, attn_metadata, step_context):
+        batch_size = attn_metadata.q_seqlens.size(0)
+        max_seqlen_q = step_context.input_ids.size(1) // batch_size
+        block_size = step_context.kv_caches[0][0].size(1)
+        window_size = (step_context.model_config.sliding_window, ) * 2
+        scheduler_metadata = _get_meta_flashattn(
+            batch_size=batch_size,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=step_context.max_kv_seqlen,
+            num_heads_q=step_context.model_config.num_attention_heads,
+            num_heads_kv=step_context.model_config.num_key_value_heads,
+            headdim=step_context.model_config.head_dim,
+            cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
+            qkv_dtype=step_context.model_config.dtype,
+            page_size=block_size,
+            window_size=window_size,
+        )
+        attn_metadata.scheduler_metadata = scheduler_metadata
+        attn_metadata.max_kv_seqlen = step_context.max_kv_seqlen
+        return attn_metadata
+
     @classmethod
     def update_step_context(cls, step_context):
         """Update step context."""
@@ -135,9 +197,10 @@ def update_step_context(cls, step_context):
         cu_seqlens_q = None
         cu_seqlens_k = None
         if use_flash_mla or use_flash_attn3:
-            cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
-            cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))
             step_context.block_offsets = step_context.block_offsets.to(torch.int32)
+            if not step_context.is_decoding:
+                cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
+                cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(kv_seqlens, dim=0, dtype=torch.int32), (1, 0))

         if not step_context.is_decoding:
             kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
@@ -160,6 +223,10 @@
             cls.update_meta_flashmla(attn_metadata,
                                      step_context.model_config.num_attention_heads * decode_query_len)

+        if use_flash_attn3:
+            if step_context.is_decoding is True:
+                attn_metadata = cls.update_meta_flashattn(attn_metadata, step_context)
+
         cross_seqlens = step_context.cross_seqlens
         cross_kv_seqlens = step_context.cross_kv_seqlens
         cross_attn_metadata = None
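A hedged usage sketch of the new _get_meta_flashattn helper for a decoding batch. The sizes below are toy values and not taken from the commit; flash-attn 3 (flash_attn_interface) must be installed and the tensors must live on a CUDA device.

    import torch

    # Per-sequence KV-cache lengths for a toy decoding batch of 4 requests.
    cache_seqlens = torch.tensor([17, 33, 65, 9], dtype=torch.int32, device='cuda')

    scheduler_metadata = _get_meta_flashattn(
        batch_size=cache_seqlens.numel(),
        max_seqlen_q=1,                       # one new token per sequence while decoding
        max_seqlen_k=int(cache_seqlens.max()),
        num_heads_q=32,                       # toy head counts and head dim
        num_heads_kv=8,
        headdim=128,
        cache_seqlens=cache_seqlens,
        qkv_dtype=torch.bfloat16,
        page_size=64,
    )
    # In the commit, the result is stashed on attn_metadata.scheduler_metadata and later
    # passed to flash_attn_with_kvcache_v3 so the kernel reuses the precomputed schedule.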

lmdeploy/pytorch/engine/engine.py

Lines changed: 8 additions & 4 deletions
@@ -854,7 +854,6 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray,

     def _debug_spec_stats(self, batched_outputs: BatchedOutputs, is_decoding: bool = False):
         """Make spec stats."""
-        # if self.speculative_config is not None and (debug or self.engine_config.enable_metrics):
         if self.speculative_config is not None:
             if not hasattr(self, 'spec_stats'):
                 from lmdeploy.metrics.stats import SpeculativeDecodingStats
@@ -880,7 +879,7 @@ def _make_infer_outputs(
         logprobs = batched_outputs.logprobs

         # for debug
-        debug = True
+        debug = False
         if debug:
             self._debug_spec_stats(batched_outputs, is_decoding=is_decoding)

@@ -912,8 +911,13 @@
             cur_logprobs = None
             if num_logprobs >= 0:
                 cur_logprobs = (logprobs.vals[idx, :num_logprobs + 1], logprobs.indices[idx, :num_logprobs + 1])
-
-            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=None)
+            # get spec stats info
+            spec_info = None
+            if self.speculative_config is not None and is_decoding and self.engine_config.enable_metrics:
+                num_draft_tokens = self.speculative_config.num_speculative_tokens
+                num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1
+                spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens)
+            req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info)
             out = InferOutput(session_id=session_id,
                               resp=msg.resp,
                               finish=finish,
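The num_accepted_tokens expression appears to rely on the convention (assumed here from the surrounding code, not stated in the commit) that rejected draft positions in next_token_ids are padded with -1, so counting valid entries and subtracting the one extra token produced by the target model gives the accepted draft count. A toy illustration:

    import torch

    # One request's row: 3 valid tokens (2 accepted drafts + 1 extra token), 2 rejected slots padded with -1.
    next_token_ids = torch.tensor([105, 17, 42, -1, -1])

    num_accepted_tokens = (next_token_ids > -1).sum() - 1   # -> tensor(2)
    print(num_accepted_tokens)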

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 9 additions & 8 deletions
@@ -514,7 +514,7 @@ def get_output(self):
                                     self._output.numel() // self._output.size(-1),
                                     device=self._output.device,
                                     dtype=self._output.dtype)
-            return strategy.slice_outputs(self._output, seqlen)
+            return strategy.slice_outputs(self._output, seqlen), self._aux_output
         torch.cuda.synchronize()
         if self._aux_output is not None:
             self._aux_output = self._aux_output.to(self._device)
@@ -796,13 +796,14 @@ async def __prepare_dp():
                 logger.debug(f'<ForwardTask> rank[{rank}]: Output [{idx}]')
                 extra_outputs = self.agent_strategy.make_extra_outputs(extra_inputs)
                 self._push_output(
-                    BatchedOutputs(next_token_ids=next_token_ids if self.spec_agent is None else extra_inputs.output_token_ids,
-                                   logits=logits if return_logits else None,
-                                   stopped=stopped,
-                                   stop_pos=stop_pos,
-                                   model_metas=model_metas,
-                                   logprobs=logprobs,
-                                   extra_outputs=extra_outputs))
+                    BatchedOutputs(
+                        next_token_ids=next_token_ids if self.spec_agent is None else extra_inputs.output_token_ids,
+                        logits=logits if return_logits else None,
+                        stopped=stopped,
+                        stop_pos=stop_pos,
+                        model_metas=model_metas,
+                        logprobs=logprobs,
+                        extra_outputs=extra_outputs))
             else:
                 # Avoid adding the ADInplaceOrView dispatch key to `next_token_ids`,
                 # as it can trigger recompilation on different ranks when using torch.compile.

lmdeploy/pytorch/model_inputs.py

Lines changed: 2 additions & 0 deletions
@@ -314,6 +314,7 @@ class StepContext:
     kv_caches: List
     is_decoding: bool
     sum_kv_seqlen: int
+    max_kv_seqlen: int = None
     local_adapter_ids: torch.LongTensor = None
     input_embeddings: torch.Tensor = None
     input_embedding_indexing: torch.Tensor = None
@@ -388,6 +389,7 @@ def new(
             kv_caches=kv_caches,
             is_decoding=inputs.is_decoding,
             sum_kv_seqlen=inputs.sum_kv_seqlen,
+            max_kv_seqlen=inputs.max_kv_seqlen,
             local_adapter_ids=inputs.local_adapter_ids,
             vision_inputs=inputs.vision_inputs,
             kv_quant_policy=kv_quant_policy,

lmdeploy/pytorch/models/utils/cudagraph.py

Lines changed: 28 additions & 3 deletions
@@ -69,8 +69,7 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, **kwargs) ->
         seqlens_dtype = torch.int64
         use_flash_mla = getattr(self.config, 'use_flash_mla', False)
         use_flash_attn3 = getattr(self.config, 'use_flash_attn3', False)
-        if use_flash_attn3 and not graph_meta.is_decoding:
-            seqlens_dtype = torch.int32
+
         if use_flash_mla is True:
             import flash_mla
             if graph_meta.is_decoding:
@@ -79,6 +78,9 @@
                 input_buffers['tile_scheduler_metadata'], input_buffers['num_splits'] = flash_mla.get_mla_metadata(
                     torch.ones(max_batches, dtype=torch.int32, device=device),
                     self.config.num_attention_heads * decode_query_len, 1)
+        elif use_flash_attn3 is True:
+            seqlens_dtype = torch.int32
+            input_buffers['scheduler_metadata'] = torch.zeros(max_batches + 1, dtype=torch.int32, device=device)

         # flash_mla requires block_offsets and kv_lens int32
         input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=seqlens_dtype, device=device)
@@ -129,7 +131,11 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
         attn_metadata.q_start_loc = input_buffers['q_start_loc']
         attn_metadata.q_seqlens = input_buffers['q_seqlens']
         attn_metadata.kv_seqlens = input_buffers['kv_seqlens']
-        if getattr(self.config, 'use_flash_mla', False) is True:
+
+        use_flash_mla = getattr(self.config, 'use_flash_mla', False)
+        use_flash_attn3 = getattr(self.config, 'use_flash_attn3', False)
+
+        if use_flash_mla is True:
             import flash_mla
             tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
                 attn_metadata.kv_seqlens.to(torch.int32), self.config.num_attention_heads * decode_query_len, 1)
@@ -139,6 +145,25 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_ids: Tensor, p
             attn_metadata.tile_scheduler_metadata = input_buffers['tile_scheduler_metadata']
             attn_metadata.num_splits = input_buffers['num_splits']

+        if use_flash_attn3:
+            from flash_attn_interface import get_scheduler_metadata
+            block_size = past_key_values[0][0].size(1)
+            # TODO may check tp>1?
+            scheduler_metadata = get_scheduler_metadata(
+                batch_size=batch_size,
+                max_seqlen_q=decode_query_len,
+                max_seqlen_k=attn_metadata.max_kv_seqlen,
+                num_heads_q=self.config.num_attention_heads,
+                num_heads_kv=self.config.num_key_value_heads,
+                headdim=self.config.head_dim,
+                cache_seqlens=attn_metadata.kv_seqlens.to(torch.int32),
+                qkv_dtype=self.config.torch_dtype,
+                page_size=block_size,
+            )
+            input_buffers['scheduler_metadata'].zero_()
+            input_buffers['scheduler_metadata'][:batch_size + 1].copy_(scheduler_metadata[:batch_size + 1])
+            attn_metadata.scheduler_metadata = input_buffers['scheduler_metadata']
+
         new_inputs = dict(
             past_key_values=past_key_values,
             attn_metadata=attn_metadata,
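The zero_() plus copy_() sequence in the last hunk follows the standard CUDA graph constraint: replay reuses the tensor addresses captured with the graph, so per-step scheduler metadata has to be written into the persistent input_buffers['scheduler_metadata'] storage rather than swapped for a freshly allocated tensor, and zeroing first clears stale entries beyond the current batch. A generic sketch of that fixed-buffer pattern (not lmdeploy code):

    import torch

    buf = torch.zeros(8, dtype=torch.int32, device='cuda')   # persistent input buffer
    out = torch.zeros(8, dtype=torch.int32, device='cuda')   # persistent output buffer

    # Warm up on a side stream before capture, as recommended for CUDA graphs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        out.copy_(buf * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out.copy_(buf * 2)                                    # kernel captured against buf's address

    for step in range(3):
        fresh = torch.arange(step, step + 8, dtype=torch.int32, device='cuda')
        buf.copy_(fresh)                                      # refresh contents in place, same storage
        g.replay()                                            # replayed kernel reads the new values
        print(out.tolist())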
