@@ -35,6 +35,7 @@ class SequenceContext:
     block_table: torch.Tensor | None = None
     device: str | torch.device = "cpu"  # TODO: this is a bit messy; device is passed around everywhere
     position_ids: torch.LongTensor | None = None
+    cu_seq_lens_pad_len: int = 0  # records the padded length of cu_seq_lens so it can be recovered in pad_cu_seq_lens

     # Intern-S1
     image_flags: torch.LongTensor | None = None
@@ -57,6 +58,8 @@ def __post_init__(self):

         self.position_ids = position_ids

+        self.pad_cu_seq_lens()
+
     @classmethod
     def from_input_ids(
         cls,
@@ -98,15 +101,27 @@ def split(self, sequence_parallel_mesh: DeviceMesh | None = None) -> Self:
         )
         new_padding = pad_input_ids.numel() - self.input_ids.numel()
         if new_padding > 0:
-            if self.num_padding > 0:
-                new_cu_seq_lens = self.cu_seq_lens_q.clone()
-                new_cu_seq_lens[-1] += new_padding
+            if self.cu_seq_lens_pad_len == 0:
+                if self.num_padding > 0:
+                    new_cu_seq_lens = self.cu_seq_lens_q.clone()
+                    new_cu_seq_lens[-1] += new_padding
+                else:
+                    new_cu_seq_lens = torch.ones(
+                        self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device
+                    )
+                    new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
+                    new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
             else:
-                new_cu_seq_lens = torch.ones(self.cu_seq_lens_q.numel() + 1, dtype=torch.int32, device=self.device)
-                new_cu_seq_lens[: self.cu_seq_lens_q.numel()] = self.cu_seq_lens_q.clone()
-                new_cu_seq_lens[-1] = self.cu_seq_lens_q[-1] + new_padding
+                new_cu_seq_lens = self.cu_seq_lens_q.clone()
+                if self.num_padding > 0:
+                    new_cu_seq_lens[-(self.cu_seq_lens_pad_len + 1) :].add_(new_padding)
+                else:
+                    new_cu_seq_lens[-self.cu_seq_lens_pad_len :].add_(new_padding)
+                    # One cu_seq_lens entry changes from a meaningless pad entry into a meaningful one (even though it corresponds to pad tokens)
+                    self.cu_seq_lens_pad_len -= 1
         else:
             new_cu_seq_lens = self.cu_seq_lens_q.clone()
+
         new_cu_seq_lens = cast(torch.IntTensor, new_cu_seq_lens)
         new_max_length = cast(int, max(self.seq_lens_q.max().item(), new_padding))
         num_non_padding = self.input_ids.shape[1] - self.num_padding
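
As a minimal sketch (made-up numbers, not part of this change) of the padded-cu_seq_lens branch above: when split() appends new_padding tokens to a context whose cu_seq_lens is already padded and that has no existing pad tokens, the last padded entry absorbs the new padding segment and the recorded pad length shrinks by one:

import torch

# Hypothetical state: two real segments (5 and 3 tokens), one padded cu_seq_lens entry,
# and no pre-existing pad tokens (num_padding == 0).
cu_seq_lens = torch.tensor([0, 5, 8, 8], dtype=torch.int32)
cu_seq_lens_pad_len = 1
new_padding = 4  # padding appended by sequence-parallel splitting

new_cu_seq_lens = cu_seq_lens.clone()
new_cu_seq_lens[-cu_seq_lens_pad_len:].add_(new_padding)
cu_seq_lens_pad_len -= 1  # the former pad entry now marks the boundary of the padding segment

print(new_cu_seq_lens.tolist(), cu_seq_lens_pad_len)  # [0, 5, 8, 12] 0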
@@ -142,21 +157,26 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
         num_padding = 0
         device = []
         inputs_embeds = []
-        for seq_ctx in sequence_context_list:
+        cu_seq_lens_is_padded = False
+        for i, seq_ctx in enumerate(sequence_context_list):
             assert seq_ctx.sequence_parallel_mesh is None
             # todo: support vlm model
             assert seq_ctx.pixel_values is None
             packed_input_ids.append(seq_ctx.input_ids)
-            cu_seq_lens_q.append(
-                seq_ctx.cu_seq_lens_q  # type: ignore
-                if len(cu_seq_lens_q) == 0
-                else (seq_ctx.cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
-            )
-            cu_seq_lens_k.append(
-                seq_ctx.cu_seq_lens_k  # type: ignore
-                if len(cu_seq_lens_k) == 0
-                else (seq_ctx.cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
-            )
+            if seq_ctx.cu_seq_lens_pad_len != 0:
+                new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
+                new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
+                new_cu_seq_lens_q = new_cu_seq_lens_q[: -seq_ctx.cu_seq_lens_pad_len]
+                new_cu_seq_lens_k = new_cu_seq_lens_k[: -seq_ctx.cu_seq_lens_pad_len]
+                cu_seq_lens_is_padded = True
+            else:
+                new_cu_seq_lens_q = seq_ctx.cu_seq_lens_q.clone()
+                new_cu_seq_lens_k = seq_ctx.cu_seq_lens_k.clone()
+            if i > 0:
+                new_cu_seq_lens_q = (new_cu_seq_lens_q + cu_seq_lens_q[-1][-1])[1:]
+                new_cu_seq_lens_k = (new_cu_seq_lens_k + cu_seq_lens_k[-1][-1])[1:]
+            cu_seq_lens_q.append(new_cu_seq_lens_q)
+            cu_seq_lens_k.append(new_cu_seq_lens_k)
             max_length_q = max(max_length_q, seq_ctx.max_length_q)
             max_length_k = max(max_length_k, seq_ctx.max_length_k)
             num_padding += seq_ctx.num_padding
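
A self-contained sketch (hypothetical tensors, not part of this change) of the packing rule above: padded tails are stripped first, and every context after the first is shifted by the running token total before its leading zero is dropped:

import torch

# Hypothetical inputs: ctx0 carries one padded entry, ctx1 is unpadded.
ctx0_cu = torch.tensor([0, 4, 6, 6], dtype=torch.int32)  # cu_seq_lens_pad_len == 1
ctx1_cu = torch.tensor([0, 3, 7], dtype=torch.int32)     # cu_seq_lens_pad_len == 0

cu_seq_lens_q = []
cu_seq_lens_q.append(ctx0_cu[:-1])            # drop the padded tail -> [0, 4, 6]
offset = cu_seq_lens_q[-1][-1]                # 6 tokens packed so far
cu_seq_lens_q.append((ctx1_cu + offset)[1:])  # -> [9, 13]

print(torch.cat(cu_seq_lens_q).tolist())      # [0, 4, 6, 9, 13]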
@@ -165,7 +185,7 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
                 inputs_embeds.append(seq_ctx.inputs_embeds)
         assert len(set(device)) == 1, f"All sequence contexts must be on the same device. Got {set(device)}"

-        return cls(
+        out = cls(
             input_ids=torch.cat(packed_input_ids, dim=1),  # type: ignore
             cu_seq_lens_q=torch.cat(cu_seq_lens_q, dim=0),  # type: ignore
             cu_seq_lens_k=torch.cat(cu_seq_lens_k, dim=0),  # type: ignore
@@ -176,6 +196,11 @@ def pack(cls, sequence_context_list: list["SequenceContext"]):
             inputs_embeds=torch.cat(inputs_embeds, dim=1) if inputs_embeds else None,  # type: ignore
         )

+        if cu_seq_lens_is_padded:
+            out = out.pad_cu_seq_lens()
+
+        return out
+
     @property
     def mask(self) -> torch.BoolTensor:
         mask: torch.BoolTensor
@@ -189,14 +214,19 @@ def mask(self) -> torch.BoolTensor:

     @property
     def seq_lens_q(self) -> torch.LongTensor:
+        # Do not slice off the padded cu_seq_lens entries here; that would again expose cu_seq_lens of varying shapes to torch.compile
         return self.cu_seq_lens_q[1:] - self.cu_seq_lens_q[:-1]  # type: ignore

     @property
     def seq_lens_k(self) -> torch.LongTensor:
+        # Do not slice off the padded cu_seq_lens entries here; that would again expose cu_seq_lens of varying shapes to torch.compile
         return self.cu_seq_lens_k[1:] - self.cu_seq_lens_k[:-1]  # type: ignore

     # TODO: currently unused; may be removed
     def chunk(self, num_chunks: int) -> list[Self]:
+        # Unused for now, so cu_seq_lens padding is not supported here yet
+        if self.cu_seq_lens_pad_len != 0:
+            raise NotImplementedError
         n = self.seq_lens_q.numel()
         assert n // num_chunks
         n_per_chunk = n // num_chunks
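
The comments above work because the padded entries merely repeat the last cumulative value, so the subtraction yields zero-length trailing segments rather than a shape change; a quick sketch with made-up values:

import torch

cu_seq_lens_q = torch.tensor([0, 5, 8, 8, 8], dtype=torch.int32)  # last two entries are padding
seq_lens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
print(seq_lens.tolist())  # [5, 3, 0, 0] -- fixed shape, zero-length segments for the pad entries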
@@ -233,6 +263,42 @@ def set_sp_mesh(self, sp_mesh: DeviceMesh) -> Self:
         self.sequence_parallel_mesh = sp_mesh
         return self

+    def pad_cu_seq_lens(self) -> Self:
+        """Pad the cumulative sequence lengths to a fixed length inferred from the sequence length.
+
+        In large-scale training (1024 GPUs or more), varying data leads to different
+        cu_seq_lens shapes, causing frequent recompilations when using torch.compile
+        for optimization and significantly slowing down training.
+        To address this, we pad cu_seq_lens to a fixed shape (inferred from seq_len)
+        and slice out the padded content during attention calculation using torch.library.custom_op,
+        ensuring training behavior remains unaffected.
+
+        Note:
+            The target length is estimated as ``seq_len // 64 + 1``.
+
+        Returns:
+            Self: The context with padded cumulative sequence lengths.
+        """
+        current_len = self.cu_seq_lens_q.shape[0]
+        seq_len = self.input_ids.shape[1]
+        cu_seq_lens_max_len_estimation = seq_len // 64 + 1
+        if self.cu_seq_lens_pad_len != 0:
+            assert current_len == cu_seq_lens_max_len_estimation
+        # assert self.cu_seq_lens_pad_len == 0, "pad_cu_seq_lens should only be called once."
+        if current_len >= cu_seq_lens_max_len_estimation:
+            return self
+        pad_len = cu_seq_lens_max_len_estimation - current_len
+        self.cu_seq_lens_pad_len = pad_len
+        assert torch.equal(self.cu_seq_lens_q, self.cu_seq_lens_k), (
+            "cu_seq_lens_q and cu_seq_lens_k must be equal to pad."
+        )
+        pad_tensor = torch.full(
+            (pad_len,), self.cu_seq_lens_q[-1], dtype=self.cu_seq_lens_q.dtype, device=self.cu_seq_lens_q.device
+        )
+        self.cu_seq_lens_q = torch.cat([self.cu_seq_lens_q, pad_tensor], dim=0)
+        self.cu_seq_lens_k = torch.cat([self.cu_seq_lens_k, pad_tensor], dim=0)
+        return self
+
     def to(self, device: torch.device | str):
         """Move all tensors in the context to the specified device.

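
A standalone sketch (hypothetical lengths, not part of this change) of the padding rule implemented by pad_cu_seq_lens: the fixed target of seq_len // 64 + 1 entries appears to assume each packed sample spans at least 64 tokens (when there are already that many boundaries, the method returns the context unpadded), and the padded entries simply repeat the last cumulative value so the extra segments have zero length:

import torch

# Hypothetical packed batch: 256 tokens in three segments.
seq_len = 256
cu_seq_lens = torch.tensor([0, 100, 180, 256], dtype=torch.int32)

target_len = seq_len // 64 + 1               # 5, the fixed shape used above
pad_len = target_len - cu_seq_lens.shape[0]  # 1
pad_tensor = torch.full((pad_len,), int(cu_seq_lens[-1]), dtype=cu_seq_lens.dtype)
padded = torch.cat([cu_seq_lens, pad_tensor], dim=0)

print(padded.tolist(), pad_len)              # [0, 100, 180, 256, 256] 1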