Commit ebb8899

⚡ Fix Flash Attention x Padding-Free loss (#4170)
1 parent 70e2017 commit ebb8899

File tree

2 files changed: +71 -107 lines changed


tests/test_sft_trainer.py

Lines changed: 24 additions & 55 deletions
@@ -64,9 +64,9 @@ def test_basic_padding(self):

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_completion_mask(self):
@@ -79,9 +79,9 @@ def test_completion_mask(self):

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

     def test_completion_only_loss_disabled(self):
@@ -95,9 +95,9 @@ def test_completion_only_loss_disabled(self):
         result = collator(examples)

         # Labels should not be masked when completion_only_loss=False
+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_padding_free_mode(self):
@@ -107,72 +107,42 @@ def test_padding_free_mode(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))

     def test_padding_free_with_completion_mask(self):
         """Test padding-free mode with completion masks."""
         collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
         examples = [
-            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
+            {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]},
             {"input_ids": [4, 5], "completion_mask": [1, 1]},
         ]

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]]))

-    def test_packing_drops_attention_mask_for_flash_attention(self):
+    def test_packing(self):
         """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
-        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
+        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)

         # Simulate packed sequences with position_ids that restart (typical of BFD packing)
         examples = [
-            {
-                "input_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # Packed: [1,2,3] + [4,5] + [6,7,8]
-                "seq_lengths": [3, 2, 3],
-            }
+            {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
+            {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
         ]

         result = collator(examples)

-        # Verify that attention_mask is NOT present - this allows FlashAttention to use position_ids
-        self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
-
-        # Verify essential keys are present
-        self.assertIn("input_ids", result)
-        self.assertIn("position_ids", result)
-        self.assertIn("labels", result)
-
-        # Verify the data is correctly processed
-        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
-
-    def test_padding_free_without_position_ids_keeps_attention_mask(self):
-        """
-        Test that padding_free mode without explicit position_ids still creates attention_mask.
-        """
-        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
-
-        # Examples without position_ids (not packed)
-        examples = [{"input_ids": [1, 2, 3, 4, 5]}]
-
-        result = collator(examples)
-
-        # Should still have attention_mask since no packed position_ids
-        self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
-        self.assertIn("position_ids", result)
-        self.assertIn("input_ids", result)
-
-        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]))
+        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]]))

     def test_pad_to_multiple_of(self):
         """Test padding to multiple of specified value."""
@@ -181,9 +151,9 @@ def test_pad_to_multiple_of(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

     def test_pad_to_multiple_of_and_padding_free(self):
@@ -193,21 +163,21 @@ def test_pad_to_multiple_of_and_padding_free(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
-        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
-        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]]))

-    def test_custom_position_ids(self):
-        """Test handling of custom position IDs in examples."""
+    def test_custom_position_ids_but_no_padding_free(self):
+        """Test that custom position_ids are ignored if padding_free is False."""
         self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
         examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}]

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_single_example(self):
@@ -217,9 +187,9 @@ def test_single_example(self):

         result = self.collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))

     def test_different_pad_token_id(self):
@@ -229,9 +199,9 @@ def test_different_pad_token_id(self):

         result = collator(examples)

+        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

     def test_assistant_masks(self):
@@ -246,7 +216,6 @@ def test_assistant_masks(self):

         torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
         torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
-        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

     def test_single_example_single_doc(self):
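
For context, the updated padding-free tests pin down the collator contract this fix relies on: with padding_free=True the batch contains only input_ids, position_ids, and labels (no attention_mask), position_ids restart at each packed sequence boundary, and the first token of every sequence is masked out of the loss with -100. A minimal sketch of that behavior, mirroring test_padding_free_mode above; the import path is an assumption, since the test module's imports are not shown in this diff:

import torch

from trl import DataCollatorForLanguageModeling  # assumed import path, not shown in the diff

collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

# No attention_mask is returned: Flash Attention recovers sequence boundaries from position_ids.
assert set(batch.keys()) == {"input_ids", "position_ids", "labels"}
# Both examples are concatenated into one row; position_ids restart at each boundary.
torch.testing.assert_close(batch["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
# The first token of each packed sequence has no valid target, so its label is -100.
torch.testing.assert_close(batch["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))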
