Commit 874f04a

pad cu_seq_lens to avoid torch recompile

1 parent 412bf9f · commit 874f04a

8 files changed: +125 -22 lines changed

ci/scripts/test_sft_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -225,7 +225,7 @@ def main():
     optim_cfg = AdamWConfig(lr=6e-05)
     lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6)
     fsdp_cfg = FSDPConfig(
-        torch_compile=False, #get_device() == "cuda",
+        torch_compile=True, #get_device() == "cuda",
         cpu_offload=False,
         ep_size=moe_cfg.ep_size,
         # hsdp_sharding_size=4,
@@ -253,7 +253,7 @@ def main():
         loss_cfg=loss_cfg,
         lr_cfg=lr_cfg,
         tokenizer_path=QWEN3_MOE_PATH,
-        global_batch_size=16,
+        global_batch_size=8,
         total_epoch=1,
         work_dir=work_dir,
         seed=0,

xtuner/v1/data_proto/sequence_context.py

Lines changed: 84 additions & 18 deletions
@@ -35,6 +35,7 @@ class SequenceContext:
     block_table: torch.Tensor | None = None
     device: str | torch.device = "cpu"  # TODO: this is a bit messy, device is scattered everywhere
     position_ids: torch.LongTensor | None = None
+    cu_seq_lens_pad_len: int = 0  # records the padded length of cu_seq_lens so it can be recovered in pad_cu_seq_lens
 
     # Intern-S1
     image_flags: torch.LongTensor | None = None
@@ -57,6 +58,8 @@ def __post_init__(self):
 
         self.position_ids = position_ids
 
+        self.pad_cu_seq_lens()
+
     @classmethod
     def from_input_ids(
         cls,
@@ -98,15 +101,27 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
         )
         new_padding = pad_input_ids.numel() - self.input_ids.numel()
         if new_padding > 0:
-            if self.num_padding > 0:
-                new_cu_seq_lens = self.cu_seq_lens_q.clone()
-                new_cu_seq_lens[-1] += new_padding
+            if self.cu_seq_lens_pad_len == 0:
+                if self.num_padding > 0:
+                    new_cu_seq_lens = self.cu_seq_lens_q.clone()
+                    new_cu_seq_lens[-1] += new_padding
+                else:
+                    new_cu_seq_lens = torch.ones(
+                        self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device
+                    )
+                    new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
+                    new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
             else:
-                new_cu_seq_lens = torch.ones(self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device)
-                new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
-                new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
+                new_cu_seq_lens = self.cu_seq_lens_q.clone()
+                if self.num_padding > 0:
+                    new_cu_seq_lens[-(self.cu_seq_lens_pad_len + 1) :].add_(new_padding)
+                else:
+                    new_cu_seq_lens[-self.cu_seq_lens_pad_len :].add_(new_padding)
+                    # one element of cu_seq_lens goes from being a meaningless pad entry to having real meaning (even though it corresponds to pad tokens)
+                    self.cu_seq_lens_pad_len -= 1
         else:
             new_cu_seq_lens = self.cu_seq_lens_q.clone()
+
         new_cu_seq_lens = cast(torch.IntTensor, new_cu_seq_lens)
         new_max_length = cast(int, max(self.seq_lens_q.max().item(), new_padding))
         num_non_padding = self.input_ids.shape[1] - self.num_padding
@@ -142,21 +157,26 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
         num_padding = 0
         device = []
         inputs_embeds = []
-        for seq_ctx in sequence_context_list:
+        cu_seq_lens_is_padded = False
+        for i, seq_ctx in enumerate(sequence_context_list):
             assert seq_ctx.sequence_parallel_mesh is None
             # todo: support vlm model
             assert seq_ctx.pixel_values is None
             packed_input_ids.append(seq_ctx.input_ids)
-            cu_seq_lens_q.append(
-                seq_ctx.cu_seq_lens_q  # type: ignore
-                if len(cu_seq_lens_q) == 0
-                else (seq_ctx.cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
-            )
-            cu_seq_lens_k.append(
-                seq_ctx.cu_seq_lens_k  # type: ignore
-                if len(cu_seq_lens_k) == 0
-                else (seq_ctx.cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
-            )
+            if seq_ctx.cu_seq_lens_pad_len != 0:
+                new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
+                new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
+                new_cu_seq_lens_q = new_cu_seq_lens_q[: -seq_ctx.cu_seq_lens_pad_len]
+                new_cu_seq_lens_k = new_cu_seq_lens_k[: -seq_ctx.cu_seq_lens_pad_len]
+                cu_seq_lens_is_padded = True
+            else:
+                new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
+                new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
+            if i > 0:
+                new_cu_seq_lens_q = (new_cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
+                new_cu_seq_lens_k = (new_cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
+            cu_seq_lens_q.append(new_cu_seq_lens_q)
+            cu_seq_lens_k.append(new_cu_seq_lens_k)
             max_length_q = max(max_length_q, seq_ctx.max_length_q)
             max_length_k = max(max_length_k, seq_ctx.max_length_k)
             num_padding += seq_ctx.num_padding
@@ -165,7 +185,7 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
             inputs_embeds.append(seq_ctx.inputs_embeds)
         assert len(set(device)) == 1, f"All sequence contexts must be on the same device. Got {set(device)}"
 
-        return cls(
+        out = cls(
             input_ids=torch.cat(packed_input_ids, dim=1),  # type: ignore
             cu_seq_lens_q=torch.cat(cu_seq_lens_q, dim=0),  # type: ignore
             cu_seq_lens_k=torch.cat(cu_seq_lens_k, dim=0),  # type: ignore
@@ -176,6 +196,11 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
             inputs_embeds=torch.cat(inputs_embeds, dim=1) if inputs_embeds else None,  # type: ignore
         )
 
+        if cu_seq_lens_is_padded:
+            out = out.pad_cu_seq_lens()
+
+        return out
+
     @property
     def mask(self) -> torch.BoolTensor:
         mask: torch.BoolTensor
@@ -189,14 +214,19 @@ def mask(self) -> torch.BoolTensor:
 
     @property
     def seq_lens_q(self) -> torch.LongTensor:
+        # do not slice off the padded cu_seq_lens here, otherwise differently shaped cu_seq_lens would again be exposed to torch.compile
        return self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1]  # type: ignore
 
     @property
     def seq_lens_k(self) -> torch.LongTensor:
+        # do not slice off the padded cu_seq_lens here, otherwise differently shaped cu_seq_lens would again be exposed to torch.compile
        return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]  # type: ignore
 
     # TODO: currently unused, may be removed
     def chunk(self, num_chunks: int) -> list[Self]:
+        # currently unused, so cu_seq_lens padding is not supported here for now
+        if self.pad_cu_seq_lens is not None and self.pad_cu_seq_lens != 0:
+            raise NotImplementedError
         n = self.seq_lens_q.numel()
         assert n // num_chunks
         n_per_chunk = n // num_chunks
@@ -233,6 +263,42 @@ def set_sp_mesh(self, sp_mesh: DeviceMesh) -> Self:
         self.sequence_parallel_mesh = sp_mesh
         return self
 
+    def pad_cu_seq_lens(self) -> Self:
+        """Pad the cumulative sequence lengths to the specified maximum length.
+
+        In large-scale training (1024 GPUs or more), varying data leads to different
+        cu_seq_lens shapes, causing frequent recompilations when using torch.compile
+        for optimization and significantly slowing down training.
+        To address this, we pad cu_seq_lens to a fixed shape (inferred from seq_len)
+        and slice out the padded content during attention calculation using torch.library.custom_op,
+        ensuring training behavior remains unaffected.
+
+        Args:
+            max_len: The target maximum length for padding.
+
+        Returns:
+            Self: The context with padded cumulative sequence lengths.
+        """
+        current_len = self.cu_seq_lens_q.shape[0]
+        seq_len = self.input_ids.shape[1]
+        cu_seq_lens_max_len_estimation = seq_len // 64 + 1
+        if self.cu_seq_lens_pad_len != 0:
+            assert current_len == cu_seq_lens_max_len_estimation
+            # assert self.cu_seq_lens_pad_len == 0, "pad_cu_seq_lens should only be called once."
+        if current_len >= cu_seq_lens_max_len_estimation:
+            return self
+        pad_len = cu_seq_lens_max_len_estimation - current_len
+        self.cu_seq_lens_pad_len = pad_len
+        assert torch.equal(self.cu_seq_lens_q, self.cu_seq_lens_k), (
+            "cu_seq_lens_q and cu_seq_lens_k must be equal to pad."
+        )
+        pad_tensor = torch.full(
+            (pad_len,), self.cu_seq_lens_q[-1], dtype=self.cu_seq_lens_q.dtype, device=self.cu_seq_lens_q.device
+        )
+        self.cu_seq_lens_q = torch.cat([self.cu_seq_lens_q, pad_tensor], dim=0)
+        self.cu_seq_lens_k = torch.cat([self.cu_seq_lens_k, pad_tensor], dim=0)
+        return self
+
     def to(self, device: torch.device | str):
         """Move all tensors in the context to the specified device.
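
To make the padding rule concrete, here is a minimal standalone sketch (not part of the commit; the tensor values and sizes are made up) of what pad_cu_seq_lens does: the fixed target length is estimated as seq_len // 64 + 1 and the tail is filled with the last cumulative value, so every padded entry describes a zero-length sequence.

import torch

# Toy reproduction of the padding rule (hypothetical values, not the repo's data).
seq_len = 512                                        # packed length, i.e. input_ids.shape[1]
cu_seq_lens = torch.tensor([0, 100, 300, 512], dtype=torch.int32)

target_len = seq_len // 64 + 1                       # fixed-shape estimate: 9
pad_len = target_len - cu_seq_lens.shape[0]          # 5 padded entries
pad_tensor = torch.full((pad_len,), cu_seq_lens[-1].item(), dtype=cu_seq_lens.dtype)
padded = torch.cat([cu_seq_lens, pad_tensor])

print(padded)                    # tensor([  0, 100, 300, 512, 512, 512, 512, 512, 512], dtype=torch.int32)
print(padded[1:] - padded[:-1])  # the padded tail shows up as zero-length sequences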

xtuner/v1/module/attention/kv_cache.py

Lines changed: 5 additions & 0 deletions
@@ -10,10 +10,14 @@ def fill_paged_kv_cache(
     value_cache: torch.Tensor,
     cu_seq_lens_q: torch.Tensor,
     cu_seq_lens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_length_q: int,
     max_length_k: int,
     block_table: torch.Tensor,
 ) -> None:
+    if cu_seqlens_pad_len > 0:
+        cu_seq_lens_q = cu_seq_lens_q[:-cu_seqlens_pad_len]
+        cu_seq_lens_k = cu_seq_lens_k[:-cu_seqlens_pad_len]
     bs = block_table.size(0)
     from lmdeploy.pytorch.kernels import fill_kv_cache
 
@@ -40,6 +44,7 @@ def fill_paged_kv_cache_fake(
     value_cache: torch.Tensor,
     cu_seq_lens_q: torch.Tensor,
     cu_seq_lens_k: torch.Tensor,
+    cu_seqlens_pad_len: int,
     max_length_q: int,
     max_length_k: int,
     block_table: torch.Tensor,
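
The slice inside fill_paged_kv_cache above is the consumer-side half of the trick: cu_seq_lens keeps a fixed shape at the graph boundary, and the padded tail is dropped only inside the custom op. Below is a hypothetical, minimal torch.library.custom_op sketch of the same pattern (the op name and shapes are illustrative, not part of this repo; requires a recent PyTorch with torch.library.custom_op):

import torch

# Hypothetical op: strip the padded tail of cu_seq_lens inside the custom op so that
# torch.compile only ever traces a fixed-shape cu_seq_lens tensor.
@torch.library.custom_op("demo::seq_lens_from_padded", mutates_args=())
def seq_lens_from_padded(cu_seq_lens: torch.Tensor, pad_len: int) -> torch.Tensor:
    if pad_len > 0:
        cu_seq_lens = cu_seq_lens[:-pad_len]   # drop the meaningless padded entries
    return cu_seq_lens[1:] - cu_seq_lens[:-1]  # per-sequence lengths


@seq_lens_from_padded.register_fake
def _(cu_seq_lens: torch.Tensor, pad_len: int) -> torch.Tensor:
    # The output shape depends only on the integer pad_len, not on tensor data.
    return cu_seq_lens.new_empty(cu_seq_lens.shape[0] - pad_len - 1)


lens = seq_lens_from_padded(torch.tensor([0, 100, 300, 512, 512], dtype=torch.int32), 1)
print(lens)  # tensor([100, 200, 212], dtype=torch.int32)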

xtuner/v1/module/attention/mha.py

Lines changed: 5 additions & 0 deletions
@@ -216,6 +216,7 @@ def prefilling(
             past_key_values[self.layer_idx][1],
             seq_ctx.cu_seq_lens_q,
             seq_ctx.cu_seq_lens_k,
+            seq_ctx.cu_seq_lens_pad_len,
             seq_ctx.max_length_q,
             seq_ctx.max_length_k,
             seq_ctx.block_table,
@@ -233,6 +234,7 @@ def prefilling(
             value_states.transpose(1, 2).squeeze(0),
             cu_seqlens_q=seq_ctx.cu_seq_lens_q,
             cu_seqlens_k=seq_ctx.cu_seq_lens_k,
+            cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
             max_seqlen_q=seq_ctx.max_length_q,
             max_seqlen_k=seq_ctx.max_length_k,
             dropout_p=self.dropout,
@@ -253,6 +255,8 @@ def decoding(
     ) -> torch.Tensor:
         assert seq_ctx.block_table is not None
         assert self.layer_idx is not None
+        if seq_ctx.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
 
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -387,6 +391,7 @@ def forward(
             value_states,
             cu_seqlens_q=seq_ctx.cu_seq_lens_q,
             cu_seqlens_k=seq_ctx.cu_seq_lens_k,
+            cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
             max_seqlen_q=seq_ctx.max_length_q,
             max_seqlen_k=seq_ctx.max_length_k,
             window_size=self.window_size,

xtuner/v1/module/attention/mla.py

Lines changed: 6 additions & 0 deletions
@@ -324,6 +324,7 @@ def forward_training(
             value_states.transpose(1, 2).squeeze(0),
             cu_seqlens_q=attn_meta.cu_seq_lens_q,
             cu_seqlens_k=attn_meta.cu_seq_lens_k,
+            cu_seqlens_pad_len=attn_meta.cu_seq_lens_pad_len,
             max_seqlen_q=attn_meta.max_length_q,
             max_seqlen_k=attn_meta.max_length_k,
             dropout_p=self.dropout,
@@ -349,6 +350,8 @@ def prefilling(
         seq_ctx: SequenceContext,
         past_key_values: list[list[torch.Tensor]],
     ) -> torch.Tensor:
+        if seq_ctx.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         bsz, q_len, _ = hidden_states.size()
 
         if self.q_lora_rank is None:
@@ -451,6 +454,8 @@ def decoding(
         seq_ctx: SequenceContext,
         past_key_values: list[list[torch.Tensor]],
     ) -> torch.Tensor:
+        if seq_ctx.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         bsz, q_len, _ = hidden_states.size()
 
         if self.q_lora_rank is None:
@@ -606,6 +611,7 @@ def forward(
             value_states.transpose(1, 2).squeeze(0),
             cu_seqlens_q=seq_ctx.cu_seq_lens_q,
             cu_seqlens_k=seq_ctx.cu_seq_lens_k,
+            cu_seqlens_pad_len=seq_ctx.cu_seq_lens_pad_len,
             max_seqlen_q=seq_ctx.max_length_q,
             max_seqlen_k=seq_ctx.max_length_k,
             dropout_p=self.dropout,

xtuner/v1/ops/attn_imp.py

Lines changed: 4 additions & 1 deletion
@@ -210,7 +210,10 @@ def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> torch.Tensor:
         attention_output = flash_attn_varlen_func(q, k, v, **kwargs)
     else:
         cu_seqlens_q = kwargs["cu_seqlens_q"]
-        attention_output = flash_sink_attn_varlen_func(q, k, v, s_aux, cu_seqlens_q, window_size[0])
+        cu_seqlens_pad_len = kwargs["cu_seqlens_pad_len"]
+        attention_output = flash_sink_attn_varlen_func(
+            q, k, v, s_aux, cu_seqlens_q, cu_seqlens_pad_len, window_size[0]
+        )
     return attention_output[None]
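
Kernels that do not understand the padding convention need the tail stripped before the call. The following is a hypothetical helper, not the repo's flash_attention dispatcher, sketching that calling pattern:

from typing import Callable

import torch

# Hypothetical wrapper: pop the pad length out of the kwargs and trim the cumulative
# lengths before handing them to a varlen kernel that expects exact cu_seqlens.
def call_varlen_kernel(kernel: Callable[..., torch.Tensor], q, k, v, **kwargs) -> torch.Tensor:
    pad_len = kwargs.pop("cu_seqlens_pad_len", 0)
    if pad_len > 0:
        kwargs["cu_seqlens_q"] = kwargs["cu_seqlens_q"][:-pad_len]
        kwargs["cu_seqlens_k"] = kwargs["cu_seqlens_k"][:-pad_len]
    return kernel(q, k, v, **kwargs)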

xtuner/v1/ops/flash_attn/flash_sink_varlen_attn_gpt_oss.py

Lines changed: 2 additions & 0 deletions
@@ -480,6 +480,7 @@ def forward(
         v: torch.Tensor,
         sink: torch.Tensor,
         cu_seqlen: torch.Tensor,
+        cu_seqlens_pad_len: int,
         window_size=None,
     ):
         if window_size == -1:
@@ -492,6 +493,7 @@ def forward(
         )
 
         ctx.save_for_backward(q, k, v, o, lse)
+        ctx.cu_seqlens_pad_len = cu_seqlens_pad_len
         ctx.sink = sink
         ctx.window_size = window_size
         ctx.cu_seqlen = cu_seqlen